mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 22:41:14 +08:00
feat: 添加少数结构性标签
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
21315a6cb0
commit
abbad781ab
@ -361,7 +361,15 @@ const labelConfig: IAutoLabelConfig = {
|
||||
0,
|
||||
0,
|
||||
0
|
||||
]
|
||||
],
|
||||
symmetry: [
|
||||
info.symmetryH ? 1 : 0,
|
||||
info.symmetryV ? 1 : 0,
|
||||
info.symmetryC ? 1 : 0
|
||||
],
|
||||
outerWall: info.outerWall ? 1 : 0,
|
||||
roomCount: info.roomCount,
|
||||
highDegBranchCount: info.highDegBranchCount
|
||||
};
|
||||
dataset.data[id] = data;
|
||||
});
|
||||
|
||||
@ -3,8 +3,10 @@ import {
|
||||
GraphNodeType,
|
||||
IAutoLabelConfig,
|
||||
IFloorInfo,
|
||||
IMapGraph,
|
||||
IMapTileConverter,
|
||||
ITowerInfo,
|
||||
MapGraphNode,
|
||||
TowerColor
|
||||
} from './types';
|
||||
import {
|
||||
@ -153,6 +155,180 @@ export function computeWallDensityStd(
|
||||
return Math.sqrt(variance);
|
||||
}
|
||||
|
||||
/**
 * Detects the three exact symmetries of a tile grid.
 *
 * A flag is true only when every tile equals its counterpart:
 * - `symmetryH` (left-right mirror): `map[y][x] === map[y][W - 1 - x]`
 * - `symmetryV` (top-bottom mirror): `map[y][x] === map[H - 1 - y][x]`
 * - `symmetryC` (point symmetry):    `map[y][x] === map[H - 1 - y][W - 1 - x]`
 *
 * The width is read from the first row. An empty grid reports all three
 * flags as true because there is nothing to contradict them.
 *
 * @param map Tile grid indexed as `map[y][x]`; callers are expected to pass
 *            the converted (label-encoded) map.
 * @returns The three symmetry flags.
 */
function computeSymmetry(map: number[][]): {
    symmetryH: boolean;
    symmetryV: boolean;
    symmetryC: boolean;
} {
    const rows = map.length;
    const cols = rows > 0 ? map[0].length : 0;
    const flags = { symmetryH: true, symmetryV: true, symmetryC: true };

    for (let y = 0; y < rows; y++) {
        const row = map[y];
        const oppositeRow = map[rows - 1 - y];
        for (let x = 0; x < cols; x++) {
            const mirrorX = cols - 1 - x;
            const tile = row[x];
            // A flag can only ever flip from true to false, so re-testing an
            // already-false flag cannot change the outcome.
            if (tile !== row[mirrorX]) flags.symmetryH = false;
            if (tile !== oppositeRow[x]) flags.symmetryV = false;
            if (tile !== oppositeRow[mirrorX]) flags.symmetryC = false;
            // Every symmetry has been ruled out: no need to scan further.
            if (!(flags.symmetryH || flags.symmetryV || flags.symmetryC)) {
                return flags;
            }
        }
    }
    return flags;
}
|
||||
|
||||
/**
 * Reports whether the outermost ring of the grid is enclosed, i.e. whether
 * wall tiles plus entry tiles make up strictly more than 90% of the ring.
 *
 * The ring is the top row, the bottom row and the first/last column of the
 * rows in between, so each corner is counted exactly once. Grids with fewer
 * than two rows or two columns always yield `false`.
 *
 * @param map   Tile grid indexed as `map[y][x]` (the converted map).
 * @param wall  Tile id that represents a wall.
 * @param entry Tile id that represents an entry.
 * @returns `true` when (wall + entry tiles on the ring) / (ring size) > 0.9.
 */
function computeOuterWall(
    map: number[][],
    wall: number,
    entry: number
): boolean {
    const rows = map.length;
    const cols = rows > 0 ? map[0].length : 0;
    if (rows < 2 || cols < 2) return false;

    const isEnclosing = (tile: number): boolean =>
        tile === wall || tile === entry;

    const topRow = map[0];
    const bottomRow = map[rows - 1];
    let enclosing = 0;

    // Full top and bottom rows, corners included.
    for (let x = 0; x < cols; x++) {
        if (isEnclosing(topRow[x])) enclosing++;
        if (isEnclosing(bottomRow[x])) enclosing++;
    }
    // Left and right columns without the corners already counted above.
    for (let y = 1; y < rows - 1; y++) {
        if (isEnclosing(map[y][0])) enclosing++;
        if (isEnclosing(map[y][cols - 1])) enclosing++;
    }

    // Ring size: two full rows plus two columns minus their corner cells.
    const ringSize = 2 * cols + 2 * (rows - 2);
    return enclosing / ringSize > 0.9;
}
|
||||
|
||||
/**
 * Counts the connected regions of the topology graph that qualify as a "room".
 *
 * Algorithm:
 * 1. Flood-fill (BFS) across Empty / Resource nodes only. Nodes of any other
 *    type (notably Branch) are never entered, so they act as region borders.
 * 2. A region is counted as a room when all of the following hold:
 *    a. at least one node of the region has a Branch neighbour;
 *    b. the region's nodes contain at least 4 tiles in total;
 *    c. the bounding box of all its tiles is wider than 1 and taller than 1.
 *
 * Tiles are flat indices decoded as `x = t % width`, `y = (t - x) / width`.
 *
 * @param graph Topology graph; `nodeMap` may reference the same node more
 *              than once, each node is processed a single time.
 * @param width Map width used to decode flat tile indices.
 * @returns Number of regions satisfying conditions a, b and c.
 */
function computeRoomCount(graph: IMapGraph, width: number): number {
    const isOpen = (node: MapGraphNode): boolean =>
        node.type === GraphNodeType.Empty ||
        node.type === GraphNodeType.Resource;

    let roomCount = 0;
    const visited = new Set<MapGraphNode>();

    for (const startNode of graph.nodeMap.values()) {
        if (!isOpen(startNode) || visited.has(startNode)) continue;

        // BFS over open nodes. The array doubles as the queue and as the
        // final region; advancing `head` dequeues in O(1), whereas
        // Array.prototype.shift() re-indexes the array on every call.
        const region: MapGraphNode[] = [startNode];
        visited.add(startNode);
        for (let head = 0; head < region.length; head++) {
            for (const nb of region[head].neighbors) {
                if (!visited.has(nb) && isOpen(nb)) {
                    visited.add(nb);
                    region.push(nb);
                }
            }
        }

        // Condition a: some node of the region touches a Branch node.
        let hasBranch = false;
        search: for (const node of region) {
            for (const nb of node.neighbors) {
                if (nb.type === GraphNodeType.Branch) {
                    hasBranch = true;
                    break search;
                }
            }
        }
        if (!hasBranch) continue;

        // Gather the tile count and the bounding box of the whole region.
        let totalTiles = 0;
        let minX = Infinity;
        let maxX = -Infinity;
        let minY = Infinity;
        let maxY = -Infinity;
        for (const node of region) {
            totalTiles += node.tiles.size;
            for (const t of node.tiles) {
                const x = t % width;
                const y = (t - x) / width;
                if (x < minX) minX = x;
                if (x > maxX) maxX = x;
                if (y < minY) minY = y;
                if (y > maxY) maxY = y;
            }
        }

        // Condition b: at least 4 tiles in total.
        if (totalTiles < 4) continue;

        // Condition c: bounding box width and height both > 1, i.e. both
        // extents (max - min) are at least 1.
        if (maxX - minX < 1 || maxY - minY < 1) continue;

        roomCount++;
    }

    return roomCount;
}
|
||||
|
||||
/**
 * Counts the high-degree branch nodes of the topology graph: nodes whose type
 * is Branch and that have at least 3 neighbours.
 *
 * `nodeMap` may reference the same node more than once; every distinct node
 * is evaluated a single time.
 *
 * @param graph Topology graph.
 * @returns Number of distinct Branch nodes with `neighbors.size >= 3`.
 */
function computeHighDegBranchCount(graph: IMapGraph): number {
    // Collapse repeated references to the same node before counting.
    const distinctNodes = new Set<MapGraphNode>(graph.nodeMap.values());

    let total = 0;
    for (const node of distinctNodes) {
        if (node.type !== GraphNodeType.Branch) continue;
        if (node.neighbors.size >= 3) total++;
    }
    return total;
}
|
||||
|
||||
/**
|
||||
* 根据地图矩阵解析出地图数据
|
||||
* @param tower 地图所属塔信息
|
||||
@ -177,6 +353,17 @@ export function parseFloorInfo(
|
||||
);
|
||||
const flattened = map.flat();
|
||||
const area = flattened.length;
|
||||
const width = map[0]?.length ?? 0;
|
||||
|
||||
// ── 结构标签计算 ─────────────────────────────────
|
||||
const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map);
|
||||
const outerWall = computeOuterWall(
|
||||
map,
|
||||
config.classes.wall,
|
||||
config.classes.entry
|
||||
);
|
||||
const roomCount = computeRoomCount(topo.graph, width);
|
||||
const highDegBranchCount = computeHighDegBranchCount(topo.graph);
|
||||
|
||||
let hasUselessBranch = false;
|
||||
|
||||
@ -283,7 +470,13 @@ export function parseFloorInfo(
|
||||
doorHeatmap: gaussainHeatmap(
|
||||
generateHeatmap(map, doorTiles, config.heatmapKernel),
|
||||
config.guassainRadius
|
||||
)
|
||||
),
|
||||
symmetryH,
|
||||
symmetryV,
|
||||
symmetryC,
|
||||
outerWall,
|
||||
roomCount,
|
||||
highDegBranchCount
|
||||
};
|
||||
|
||||
return floorInfo;
|
||||
|
||||
@ -112,6 +112,20 @@ export interface IFloorInfo {
|
||||
readonly entryHeatmap: number[][];
|
||||
/** 门热力图 */
|
||||
readonly doorHeatmap: number[][];
|
||||
|
||||
// ── 结构标签(新增)──────────────────────────────
|
||||
/** 左右对称(基于 convertedMap 完全匹配) */
|
||||
readonly symmetryH: boolean;
|
||||
/** 上下对称 */
|
||||
readonly symmetryV: boolean;
|
||||
/** 中心对称 */
|
||||
readonly symmetryC: boolean;
|
||||
/** 是否外包围墙壁(最外圈墙壁+入口占比 > 90%) */
|
||||
readonly outerWall: boolean;
|
||||
/** 房间数量原始值(供 Python 两趟扫描使用) */
|
||||
readonly roomCount: number;
|
||||
/** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */
|
||||
readonly highDegBranchCount: number;
|
||||
}
|
||||
|
||||
export interface IMapBlockConfig {
|
||||
|
||||
@ -48,6 +48,15 @@ export interface GinkaTrainData {
|
||||
map: number[][];
|
||||
size: [number, number];
|
||||
heatmap?: number[][][];
|
||||
// ── 结构标签(新增)──────────────────────────────
|
||||
/** 对称性:[symmetryH, symmetryV, symmetryC],0 或 1 */
|
||||
symmetry: [number, number, number];
|
||||
/** 是否外包围墙壁,0 或 1 */
|
||||
outerWall: number;
|
||||
/** 房间数量原始值 */
|
||||
roomCount: number;
|
||||
/** 高连接度分支节点数量原始值 */
|
||||
highDegBranchCount: number;
|
||||
}
|
||||
|
||||
export interface GinkaDataset {
|
||||
|
||||
248
docs/dataset-labels-design.md
Normal file
248
docs/dataset-labels-design.md
Normal file
@ -0,0 +1,248 @@
|
||||
# 数据集标签设计文档
|
||||
|
||||
## 背景
|
||||
|
||||
在当前 VQ-VAE + MaskGIT 的联合训练方案中,VQ-VAE 的 codebook 承担着地图风格与多样性的控制职能,但缺少用户可感知的语义维度。为提升生成的可控性,计划在数据集中添加一组**可程序化标注的结构标签**,作为额外条件输入训练,从而使模型能够接受来自用户的高层语义约束(如"生成一个左右对称的地图"、"高房间数量"等)。
|
||||
|
||||
所有标签均可在数据集构建阶段(`info.ts` / `parseFloorInfo`)或训练前的两趟扫描中自动计算,无需人工标注。
|
||||
|
||||
---
|
||||
|
||||
## 标签一览
|
||||
|
||||
| 标签名 | 类型 | 取值 | 标注时机 |
|
||||
| ---------------- | --------- | ------------------- | -------- |
|
||||
| `symmetryH` | `boolean` | 左右对称 | 单张地图 |
|
||||
| `symmetryV` | `boolean` | 上下对称 | 单张地图 |
|
||||
| `symmetryC` | `boolean` | 中心对称 | 单张地图 |
|
||||
| `roomCountLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 |
|
||||
| `branchLevel` | `0\|1\|2` | Low / Medium / High | 两趟扫描 |
|
||||
| `outerWall` | `boolean` | 是否外包围墙壁 | 单张地图 |
|
||||
|
||||
---
|
||||
|
||||
## 标签一:对称性
|
||||
|
||||
### 定义
|
||||
|
||||
针对经过转换后的地图(`convertedMap`)矩阵逐格比较,仅当**完全满足条件**时才标记为对应对称类型。三种对称相互独立,可同时成立。
|
||||
|
||||
| 对称类型 | 条件(对所有 `x ∈ [0, W)`, `y ∈ [0, H)` 成立) |
|
||||
| -------- | ---------------------------------------------- |
|
||||
| 左右对称 | `map[y][x] === map[y][W - 1 - x]` |
|
||||
| 上下对称 | `map[y][x] === map[H - 1 - y][x]` |
|
||||
| 中心对称 | `map[y][x] === map[H - 1 - y][W - 1 - x]` |
|
||||
|
||||
### 实现要点
|
||||
|
||||
- 比较使用 `convertedMap`(标签化图块编号),而非原始 `originMap`,使不同塔的同类图块具有可比性。
|
||||
- 单独一种对称并不蕴含另一种(例如中心对称的地图不一定左右或上下对称,反之亦然),因此三者需独立计算;但三者中任意两种同时成立时,第三种必然成立(例如左右对称且上下对称 ⇒ 中心对称)。
|
||||
- 对于奇数尺寸的地图(如 13×13),中心行/列与自身比较必然成立,无需特殊处理。
|
||||
|
||||
### 字段扩展(`IFloorInfo`)
|
||||
|
||||
```typescript
|
||||
/** 左右对称 */
|
||||
readonly symmetryH: boolean;
|
||||
/** 上下对称 */
|
||||
readonly symmetryV: boolean;
|
||||
/** 中心对称 */
|
||||
readonly symmetryC: boolean;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 标签二:房间数量等级
|
||||
|
||||
### 定义
|
||||
|
||||
**房间(Room)** 是地图中由"空白节点"或"资源节点"组成的区域,同时满足以下三个条件:
|
||||
|
||||
1. **位置条件**:区域内至少有一个节点在拓扑图中与 **至少 1 个分支节点(Branch)** 相邻;
|
||||
2. **面积条件**:区域内所有节点所包含的地图格子总数(各节点 `tiles.size` 之和)**≥ 4**;
|
||||
3. **形状条件**:区域内所有格子的**外接矩形的宽和高都大于 1**(避免把单行/单列走廊计入房间)。
|
||||
|
||||
> **为什么这样定义房间**
|
||||
>
|
||||
> 在拓扑图中,空白/资源节点天然是游戏空间的"腔体",而分支节点(门/怪物)是进入腔体的关卡节点。只要与至少一个分支节点相邻,就说明这片空间是需要"先过关才能进入/离开"的区域——包括怪物守着宝箱这类单入口房间,同样是典型的房间结构。面积和形状约束则过滤掉通道和死胡同。
|
||||
|
||||
### 等级划分
|
||||
|
||||
等级为 **三档**:Low(0)/ Medium(1)/ High(2),通过以下两趟扫描确定:
|
||||
|
||||
1. **第一趟**:遍历整个训练集,计算每张地图的房间数量 `roomCount`,收集为数组;
|
||||
2. **第二趟**:对数组升序排序,取 1/3 和 2/3 分位数作为阈值 `[th1, th2]`:
|
||||
- `roomCount < th1` → Low(0)
|
||||
- `th1 ≤ roomCount < th2` → Medium(1)
|
||||
- `roomCount ≥ th2` → High(2)
|
||||
|
||||
等级划分力求三档样本数量均等(等频分箱),而非等距分箱。
|
||||
|
||||
### 外接矩形计算
|
||||
|
||||
给定节点的所有格子坐标集合(`tiles`,存储 `y * width + x` 的平坦坐标),还原为 `(x, y)` 后:
|
||||
|
||||
$$
|
||||
x_{\min} = \min_{t \in tiles}(t \bmod W), \quad x_{\max} = \max_{t \in tiles}(t \bmod W)
|
||||
$$
|
||||
|
||||
$$
|
||||
y_{\min} = \min_{t \in tiles}(\lfloor t / W \rfloor), \quad y_{\max} = \max_{t \in tiles}(\lfloor t / W \rfloor)
|
||||
$$
|
||||
|
||||
外接矩形宽 = $x_{\max} - x_{\min} + 1$,高 = $y_{\max} - y_{\min} + 1$,两者均需 $> 1$。
|
||||
|
||||
### 字段扩展
|
||||
|
||||
```typescript
|
||||
/** 房间数量(原始统计值,供两趟扫描使用) */
|
||||
readonly roomCount: number;
|
||||
/** 房间数量等级:0=Low, 1=Medium, 2=High(需两趟扫描后赋值) */
|
||||
roomCountLevel: 0 | 1 | 2;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 标签三:分支数量等级
|
||||
|
||||
### 定义
|
||||
|
||||
**高连接度分支节点**:拓扑图中 `type === Branch` 的节点,其**非墙邻居节点**总数(`neighbors.size`)**≥ 3**。由于当前拓扑图中已不含墙节点,`neighbors.size` 即等于非墙邻居数,两者在实现上等价。
|
||||
|
||||
这类节点是地图中的"交叉口"——一个门或怪物后方至少有三条不同路线,是地图分叉度和策略深度的指征。
|
||||
|
||||
等级划分方式与房间数量等级相同:先统计每张地图中高连接度分支节点的数量,再等频分箱为 Low(0)/ Medium(1)/ High(2)。
|
||||
|
||||
### 与房间数量的区别
|
||||
|
||||
| 维度 | 房间数量等级 | 分支数量等级 |
|
||||
| -------- | -------------------------- | ---------------------------- |
|
||||
| 度量对象 | 空白/资源节点区域的封闭性 | 分支节点的路径分叉度 |
|
||||
| 反映特征 | 地图内封闭房间的数量与密度 | 关键路口/多分支门怪的复杂度 |
|
||||
| 典型高值 | 多房间迷宫风格地图 | 高度分叉、策略选择丰富的地图 |
|
||||
|
||||
### 字段扩展
|
||||
|
||||
```typescript
|
||||
/** 高连接度分支节点数量(原始统计值) */
|
||||
readonly highDegBranchCount: number;
|
||||
/** 分支数量等级:0=Low, 1=Medium, 2=High(需两趟扫描后赋值) */
|
||||
branchLevel: 0 | 1 | 2;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 标签四:外包围墙壁
|
||||
|
||||
### 定义
|
||||
|
||||
地图**最外圈**(最外一圈格子)的格子中,**墙壁格子与入口格子**之和占外圈总格子数的比例 > **90%**,则标记为 `outerWall = true`。
|
||||
|
||||
$$
|
||||
\text{outerWall} = \frac{|\{(x,y) \in \text{border} : \text{isWall}(x,y) \lor \text{isEntry}(x,y)\}|}{|\text{border}|} > 0.9
|
||||
$$
|
||||
|
||||
### 最外圈定义
|
||||
|
||||
对于 $H \times W$ 的地图,最外圈为所有满足下列条件之一的格子:
|
||||
|
||||
$$
|
||||
x = 0 \; \lor \; x = W - 1 \; \lor \; y = 0 \; \lor \; y = H - 1
|
||||
$$
|
||||
|
||||
最外圈格子总数为 $2(H + W) - 4$(对于 13×13,共 48 格)。
|
||||
|
||||
### 为什么入口也算"通过"
|
||||
|
||||
入口格子在游戏中是楼梯/传送点,不属于可通行的空地,在视觉和结构上等价于边界开口,属于外圈围合结构的合理组成部分,不应被视为"破坏围合"的元素。
|
||||
|
||||
### 实现要点
|
||||
|
||||
- 使用 `convertedMap` 判断墙壁(`tile === config.classes.wall`);
|
||||
- 使用 `originMap` + `converter.isEntry()`,或直接对 `convertedMap` 判断入口(`tile === config.classes.entry`);
|
||||
- 两项取析取(格子是墙壁**或**入口,满足其一即可)计入分子。
|
||||
|
||||
### 字段扩展
|
||||
|
||||
```typescript
|
||||
/** 是否外包围墙壁 */
|
||||
readonly outerWall: boolean;
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 实现方案
|
||||
|
||||
### 单张地图可直接计算的标签
|
||||
|
||||
以下标签可在 `parseFloorInfo` 中直接计算,加入 `IFloorInfo`:
|
||||
|
||||
- `symmetryH`、`symmetryV`、`symmetryC`
|
||||
- `outerWall`
|
||||
- `roomCount`(原始值)
|
||||
- `highDegBranchCount`(原始值)
|
||||
|
||||
### 需要两趟扫描的等级标签
|
||||
|
||||
`roomCountLevel` 和 `branchLevel` 依赖全局分位数,须在数据集构建完成后进行二次处理:
|
||||
|
||||
```
|
||||
第一趟:构建所有楼层的 IFloorInfo,写入 roomCount / highDegBranchCount
|
||||
第二趟:收集所有楼层的原始值 → 计算 1/3, 2/3 分位 → 回填等级
|
||||
```
|
||||
|
||||
在 Python 训练侧,推荐方式:
|
||||
|
||||
```python
|
||||
# 在 Dataset.__init__ 中完成两趟计算
|
||||
counts = [item['roomCount'] for item in raw_data]
|
||||
counts_sorted = sorted(counts)
|
||||
th1 = counts_sorted[len(counts_sorted) // 3]
|
||||
th2 = counts_sorted[2 * len(counts_sorted) // 3]
|
||||
|
||||
for item in raw_data:
|
||||
c = item['roomCount']
|
||||
item['roomCountLevel'] = 0 if c < th1 else (1 if c < th2 else 2)
|
||||
```
|
||||
|
||||
同理处理 `branchLevel`。
|
||||
|
||||
> **注意**:分位数阈值应仅基于**训练集**统计,验证集 / 测试集使用相同的阈值映射,避免数据泄露。
|
||||
|
||||
---
|
||||
|
||||
## 训练集成
|
||||
|
||||
### 条件嵌入
|
||||
|
||||
将上述标签作为离散条件与 VQ-VAE 的 z 一同注入 MaskGIT:
|
||||
|
||||
```python
|
||||
# 对称性:三个独立布尔值,可合并为 0~7 的整数 cond_sym
|
||||
cond_sym = symmetryH * 4 + symmetryV * 2 + symmetryC * 1 # [0, 7]
|
||||
|
||||
# 房间等级:0 / 1 / 2
|
||||
cond_room = roomCountLevel
|
||||
|
||||
# 分支等级:0 / 1 / 2
|
||||
cond_branch = branchLevel
|
||||
|
||||
# 外包围墙壁:0 / 1
|
||||
cond_outer = int(outerWall)
|
||||
```
|
||||
|
||||
每个条件通过独立的 `nn.Embedding` 映射为固定维度向量,与 VQ-VAE 的 z 序列沿序列维度拼接后,一同经 Cross-Attention 注入 MaskGIT。
|
||||
|
||||
### 条件 Dropout
|
||||
|
||||
与 z dropout 类似,训练时以一定概率(如 10~20%)将部分或全部结构标签替换为"无条件"(null embedding),使模型在推理时支持条件缺省(CFG 风格)。
|
||||
|
||||
---
|
||||
|
||||
## 待细化事项
|
||||
|
||||
- [x] 90% 阈值是否合适?——**保持 90%**,如后续数据分布分析发现问题再调整。
|
||||
- [x] 房间定义中分支节点邻居数量——**改为至少 1 个**,覆盖怪物守宝箱的单入口房间场景。
|
||||
- [x] 分支等级邻居计数口径——**使用非墙邻居**(当前图结构中无墙节点,与 `neighbors.size` 等价)。
|
||||
- [x] 是否新增通道数量/路径长度标签——**暂不考虑**。
|
||||
- [x] 条件嵌入维度对齐——**各标签 Embedding 与 z 序列拼接后统一经 Cross-Attention 注入**。
|
||||
563
docs/dataset-labels-impl.md
Normal file
563
docs/dataset-labels-impl.md
Normal file
@ -0,0 +1,563 @@
|
||||
# 数据集标签实现指南
|
||||
|
||||
## 总览
|
||||
|
||||
本文档描述将[标签设计文档](./dataset-labels-design.md)中定义的四类结构标签落地到代码中的具体实现步骤,涉及以下文件:
|
||||
|
||||
| 文件 | 变更性质 |
|
||||
| ------------------------ | ------------------------------------------- |
|
||||
| `data/src/auto/types.ts` | 扩展 `IFloorInfo` 接口 |
|
||||
| `data/src/types.ts` | 扩展 `GinkaTrainData` 接口 |
|
||||
| `data/src/auto/info.ts` | 新增四个 helper 函数,修改 `parseFloorInfo` |
|
||||
| `data/src/auto.ts` | 在数据集序列化时写入新字段 |
|
||||
| `ginka/dataset.py` | 两趟扫描 + `__getitem__` 读取新标签 |
|
||||
|
||||
---
|
||||
|
||||
## 第一步:扩展 TypeScript 类型
|
||||
|
||||
### `data/src/auto/types.ts` — 扩展 `IFloorInfo`
|
||||
|
||||
在 `IFloorInfo` 接口的 `doorHeatmap` 字段之后追加以下字段:
|
||||
|
||||
```typescript
|
||||
// ── 结构标签(新增)──────────────────────────────
|
||||
/** 左右对称(基于 convertedMap 完全匹配) */
|
||||
readonly symmetryH: boolean;
|
||||
/** 上下对称 */
|
||||
readonly symmetryV: boolean;
|
||||
/** 中心对称 */
|
||||
readonly symmetryC: boolean;
|
||||
/** 是否外包围墙壁(最外圈墙壁+入口占比 > 90%) */
|
||||
readonly outerWall: boolean;
|
||||
/** 房间数量原始值(供 Python 两趟扫描使用) */
|
||||
readonly roomCount: number;
|
||||
/** 高连接度分支节点数量原始值(供 Python 两趟扫描使用) */
|
||||
readonly highDegBranchCount: number;
|
||||
```
|
||||
|
||||
### `data/src/types.ts` — 扩展 `GinkaTrainData`
|
||||
|
||||
在 `GinkaTrainData` 接口中追加序列化后写入 JSON 的字段:
|
||||
|
||||
```typescript
|
||||
export interface GinkaTrainData {
|
||||
tag?: number[];
|
||||
val: number[];
|
||||
map: number[][];
|
||||
size: [number, number];
|
||||
heatmap?: number[][][];
|
||||
// ── 结构标签(新增)──────────────────────────────
|
||||
/** 对称性:[symmetryH, symmetryV, symmetryC],0 或 1 */
|
||||
symmetry: [number, number, number];
|
||||
/** 是否外包围墙壁,0 或 1 */
|
||||
outerWall: number;
|
||||
/** 房间数量原始值 */
|
||||
roomCount: number;
|
||||
/** 高连接度分支节点数量原始值 */
|
||||
highDegBranchCount: number;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第二步:实现 Helper 函数(`data/src/auto/info.ts`)
|
||||
|
||||
以下四个函数添加到 `computeWallDensityStd` 之后、`parseFloorInfo` 之前。
|
||||
|
||||
### 2.1 对称性计算
|
||||
|
||||
```typescript
|
||||
/**
|
||||
* 计算地图的三种对称性(基于 convertedMap,完全匹配才标记为 true)
|
||||
*/
|
||||
function computeSymmetry(map: number[][]): {
|
||||
symmetryH: boolean;
|
||||
symmetryV: boolean;
|
||||
symmetryC: boolean;
|
||||
} {
|
||||
const H = map.length;
|
||||
const W = H > 0 ? map[0].length : 0;
|
||||
let symmetryH = true;
|
||||
let symmetryV = true;
|
||||
let symmetryC = true;
|
||||
|
||||
outer: for (let y = 0; y < H; y++) {
|
||||
for (let x = 0; x < W; x++) {
|
||||
const tile = map[y][x];
|
||||
if (symmetryH && tile !== map[y][W - 1 - x]) symmetryH = false;
|
||||
if (symmetryV && tile !== map[H - 1 - y][x]) symmetryV = false;
|
||||
if (symmetryC && tile !== map[H - 1 - y][W - 1 - x])
|
||||
symmetryC = false;
|
||||
// 三种对称均已排除,提前退出
|
||||
if (!symmetryH && !symmetryV && !symmetryC) break outer;
|
||||
}
|
||||
}
|
||||
return { symmetryH, symmetryV, symmetryC };
|
||||
}
|
||||
```
|
||||
|
||||
**复杂度**:O(H × W),最坏情况遍历全图一次,实际有短路优化。
|
||||
|
||||
**注意**:对于奇数尺寸地图(如 13×13),中心行/列的格子与自身比较恒为 true,不影响结果。
|
||||
|
||||
### 2.2 外包围墙壁检测
|
||||
|
||||
```typescript
|
||||
/**
|
||||
* 检测地图最外圈中墙壁 + 入口的占比是否超过 90%
|
||||
* @param map convertedMap
|
||||
* @param wall 墙壁图块编号
|
||||
* @param entry 入口图块编号
|
||||
*/
|
||||
function computeOuterWall(
|
||||
map: number[][],
|
||||
wall: number,
|
||||
entry: number
|
||||
): boolean {
|
||||
const H = map.length;
|
||||
const W = H > 0 ? map[0].length : 0;
|
||||
if (H < 2 || W < 2) return false;
|
||||
|
||||
let borderCount = 0;
|
||||
let wallOrEntry = 0;
|
||||
|
||||
const check = (tile: number) => {
|
||||
borderCount++;
|
||||
if (tile === wall || tile === entry) wallOrEntry++;
|
||||
};
|
||||
|
||||
// 顶行 + 底行
|
||||
for (let x = 0; x < W; x++) {
|
||||
check(map[0][x]);
|
||||
check(map[H - 1][x]);
|
||||
}
|
||||
// 左列 + 右列(排除角格,已由上面计入)
|
||||
for (let y = 1; y < H - 1; y++) {
|
||||
check(map[y][0]);
|
||||
check(map[y][W - 1]);
|
||||
}
|
||||
|
||||
// 13×13 地图: borderCount = 2*(13+13)-4 = 48
|
||||
return borderCount > 0 && wallOrEntry / borderCount > 0.9;
|
||||
}
|
||||
```
|
||||
|
||||
### 2.3 房间数量统计
|
||||
|
||||
**核心思路**:拓扑图中 Empty 节点和 Resource 节点在物理上可能彼此相邻,共同构成一个连续的游戏空间(如 2×3 的房间内放了一个宝物)。不能逐节点独立判断,而应先通过 BFS 将相邻的 Empty/Resource 节点合并为一个"候选区域",再对整个区域判断三个条件。Branch 节点作为边界,不会被合并进区域。
|
||||
|
||||
```typescript
|
||||
/**
|
||||
* 统计拓扑图中符合"房间"定义的连通区域数量。
|
||||
*
|
||||
* 算法:
|
||||
* 1. 以 Empty / Resource 节点为顶点,在它们之间 BFS,
|
||||
* 得到若干"候选区域"(Branch 节点作为边界,不被合并)。
|
||||
* 2. 对每个候选区域检查三个条件:
|
||||
* a. 区域内至少一个节点有 Branch 类型邻居
|
||||
* b. 区域内所有格子总数 >= 4
|
||||
* c. 所有格子的外接矩形宽 > 1 且高 > 1
|
||||
*
|
||||
* @param graph 拓扑图
|
||||
* @param width 地图宽度(用于平坦坐标解码)
|
||||
*/
|
||||
function computeRoomCount(graph: IMapGraph, width: number): number {
|
||||
// Step 1: 收集所有 Empty 和 Resource 节点(去重)
|
||||
const allEmptyResource = new Set<MapGraphNode>();
|
||||
for (const node of graph.nodeMap.values()) {
|
||||
if (
|
||||
node.type === GraphNodeType.Empty ||
|
||||
node.type === GraphNodeType.Resource
|
||||
) {
|
||||
allEmptyResource.add(node);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: BFS 将相邻的 Empty/Resource 节点合并为连通区域
|
||||
let roomCount = 0;
|
||||
const visited = new Set<MapGraphNode>();
|
||||
|
||||
for (const startNode of allEmptyResource) {
|
||||
if (visited.has(startNode)) continue;
|
||||
|
||||
// BFS:只在 Empty/Resource 节点间传播,Branch 节点阻断
|
||||
const regionNodes = new Set<MapGraphNode>();
|
||||
const queue: MapGraphNode[] = [startNode];
|
||||
visited.add(startNode);
|
||||
|
||||
while (queue.length > 0) {
|
||||
const current = queue.shift()!;
|
||||
regionNodes.add(current);
|
||||
for (const nb of current.neighbors) {
|
||||
if (
|
||||
!visited.has(nb) &&
|
||||
(nb.type === GraphNodeType.Empty ||
|
||||
nb.type === GraphNodeType.Resource)
|
||||
) {
|
||||
visited.add(nb);
|
||||
queue.push(nb);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: 对合并后的区域检查三个条件
|
||||
|
||||
// 条件 a:区域内任一节点有 Branch 邻居
|
||||
let hasBranch = false;
|
||||
outer: for (const node of regionNodes) {
|
||||
for (const nb of node.neighbors) {
|
||||
if (nb.type === GraphNodeType.Branch) {
|
||||
hasBranch = true;
|
||||
break outer;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!hasBranch) continue;
|
||||
|
||||
// 收集区域内所有格子,计算总数和外接矩形
|
||||
let totalTiles = 0;
|
||||
let minX = Infinity,
|
||||
maxX = -Infinity,
|
||||
minY = Infinity,
|
||||
maxY = -Infinity;
|
||||
|
||||
for (const node of regionNodes) {
|
||||
totalTiles += node.tiles.size;
|
||||
for (const t of node.tiles) {
|
||||
const x = t % width;
|
||||
const y = (t - x) / width;
|
||||
if (x < minX) minX = x;
|
||||
if (x > maxX) maxX = x;
|
||||
if (y < minY) minY = y;
|
||||
if (y > maxY) maxY = y;
|
||||
}
|
||||
}
|
||||
|
||||
// 条件 b:总格子数 >= 4
|
||||
if (totalTiles < 4) continue;
|
||||
|
||||
// 条件 c:外接矩形宽高均 > 1(即 maxX - minX >= 1 && maxY - minY >= 1)
|
||||
if (maxX - minX < 1 || maxY - minY < 1) continue;
|
||||
|
||||
roomCount++;
|
||||
}
|
||||
|
||||
return roomCount;
|
||||
}
|
||||
```
|
||||
|
||||
**边界情况讨论**:
|
||||
|
||||
- **混合节点房间**:2×3 房间(5 个空地 + 1 个资源)→ 1 个 Empty 节点 + 1 个 Resource 节点,BFS 合并后总格子数=6,外接矩形 2×3,**计入房间**。
|
||||
- **纯条形走廊**:1×4 的 Empty 节点 → 外接矩形高=1,**不计入房间**。
|
||||
- **孤立资源死角**:Resource 节点的邻居全是 Entry 而无 Branch → 条件 a 不满足,**不计入房间**。
|
||||
- **两个相邻的 Empty 区域被 Branch 隔开**:BFS 不会跨越 Branch,各自单独判断。
|
||||
|
||||
### 2.4 高连接度分支节点统计
|
||||
|
||||
```typescript
|
||||
/**
|
||||
* 统计邻居数 >= 3 的分支节点数量
|
||||
*
|
||||
* 由于拓扑图中已不含墙节点,neighbors.size 即等于非墙邻居数。
|
||||
*
|
||||
* @param graph 拓扑图
|
||||
*/
|
||||
function computeHighDegBranchCount(graph: IMapGraph): number {
|
||||
let count = 0;
|
||||
const visited = new Set<MapGraphNode>();
|
||||
|
||||
for (const node of graph.nodeMap.values()) {
|
||||
if (visited.has(node)) continue;
|
||||
visited.add(node);
|
||||
|
||||
if (node.type === GraphNodeType.Branch && node.neighbors.size >= 3) {
|
||||
count++;
|
||||
}
|
||||
}
|
||||
return count;
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第三步:在 `parseFloorInfo` 中调用
|
||||
|
||||
在 `parseFloorInfo` 函数内,`topo` 构建完毕、`floorInfo` 对象构造之前,调用上述四个函数:
|
||||
|
||||
```typescript
|
||||
// ── 结构标签计算 ─────────────────────────────────
|
||||
const width = map[0]?.length ?? 0;
|
||||
const { symmetryH, symmetryV, symmetryC } = computeSymmetry(map);
|
||||
const outerWall = computeOuterWall(
|
||||
map,
|
||||
config.classes.wall,
|
||||
config.classes.entry
|
||||
);
|
||||
const roomCount = computeRoomCount(topo.graph, width);
|
||||
const highDegBranchCount = computeHighDegBranchCount(topo.graph);
|
||||
```
|
||||
|
||||
然后在 `floorInfo` 字面量中追加这些字段:
|
||||
|
||||
```typescript
|
||||
const floorInfo: IFloorInfo = {
|
||||
// ...(原有字段保持不变)
|
||||
symmetryH,
|
||||
symmetryV,
|
||||
symmetryC,
|
||||
outerWall,
|
||||
roomCount,
|
||||
highDegBranchCount
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第四步:在 `data/src/auto.ts` 中序列化新字段
|
||||
|
||||
在构建 `GinkaTrainData` 的对象字面量中追加:
|
||||
|
||||
```typescript
|
||||
const data: GinkaTrainData = {
|
||||
map: floor.data.map,
|
||||
size: [width, height],
|
||||
heatmap: [
|
||||
/* ...原有热力图通道... */
|
||||
],
|
||||
val: [
|
||||
/* ...原有标量... */
|
||||
],
|
||||
// ── 新增结构标签 ──────────────────────────────
|
||||
symmetry: [
|
||||
info.symmetryH ? 1 : 0,
|
||||
info.symmetryV ? 1 : 0,
|
||||
info.symmetryC ? 1 : 0
|
||||
],
|
||||
outerWall: info.outerWall ? 1 : 0,
|
||||
roomCount: info.roomCount,
|
||||
highDegBranchCount: info.highDegBranchCount
|
||||
};
|
||||
```
|
||||
|
||||
布尔值存为 `0/1` 整数,便于 Python 侧直接读取,不需要类型转换。
|
||||
|
||||
---
|
||||
|
||||
## 第五步:Python 侧两趟扫描(`ginka/dataset.py`)
|
||||
|
||||
### 5.1 等频分箱函数
|
||||
|
||||
在 `dataset.py` 顶部添加通用的等频分箱 helper:
|
||||
|
||||
```python
|
||||
def assign_level(values: list[int]) -> list[int]:
|
||||
"""
|
||||
将整数列表按等频分箱映射为 0/1/2 三档等级。
|
||||
分位数阈值基于当前列表计算(训练集与验证集应共用训练集的阈值)。
|
||||
|
||||
Args:
|
||||
values: 原始统计值列表,顺序与数据集一一对应
|
||||
|
||||
Returns:
|
||||
与输入等长的等级列表,每项为 0 / 1 / 2
|
||||
"""
|
||||
n = len(values)
|
||||
if n == 0:
|
||||
return []
|
||||
sorted_vals = sorted(values)
|
||||
th1 = sorted_vals[n // 3]
|
||||
th2 = sorted_vals[2 * n // 3]
|
||||
return [
|
||||
0 if v < th1 else (1 if v < th2 else 2)
|
||||
for v in values
|
||||
]
|
||||
```
|
||||
|
||||
### 5.2 修改 `GinkaVQDataset.__init__`
|
||||
|
||||
在加载数据之后立即进行两趟扫描,并将等级结果回填到各 item 中:
|
||||
|
||||
```python
|
||||
class GinkaVQDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_path: str,
|
||||
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
||||
wall_mask_min: float = 0.0,
|
||||
wall_mask_max: float = 0.5,
|
||||
# 以下两个参数在训练集上保持 None(此时自动从当前数据计算阈值);验证集/测试集传入训练集的阈值即可
|
||||
room_thresholds: tuple[int, int] | None = None,
|
||||
branch_thresholds: tuple[int, int] | None = None,
|
||||
):
|
||||
self.data = load_data(data_path)
|
||||
# ...(原有初始化逻辑)
|
||||
|
||||
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
|
||||
room_counts = [item['roomCount'] for item in self.data]
|
||||
branch_counts = [item['highDegBranchCount'] for item in self.data]
|
||||
|
||||
if room_thresholds is None:
|
||||
# 训练集:自行计算阈值
|
||||
n = len(room_counts)
|
||||
rs = sorted(room_counts)
|
||||
bs = sorted(branch_counts)
|
||||
self.room_th = (rs[n // 3], rs[2 * n // 3])
|
||||
self.branch_th = (bs[n // 3], bs[2 * n // 3])
|
||||
else:
|
||||
# 验证集/测试集:直接使用训练集的阈值,避免数据泄露
|
||||
self.room_th = room_thresholds
|
||||
self.branch_th = branch_thresholds
|
||||
|
||||
def to_level(v: int, th: tuple[int, int]) -> int:
|
||||
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||
|
||||
# 回填等级字段
|
||||
for item in self.data:
|
||||
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
|
||||
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||||
```
|
||||
|
||||
**调用方式**(`train_vq.py` 中):
|
||||
|
||||
```python
|
||||
dataset_train = GinkaVQDataset(args.train)
|
||||
dataset_val = GinkaVQDataset(
|
||||
args.validate,
|
||||
room_thresholds=dataset_train.room_th,
|
||||
branch_thresholds=dataset_train.branch_th
|
||||
)
|
||||
```
|
||||
|
||||
### 5.3 `__getitem__` 中读取结构标签
|
||||
|
||||
对称性标签在数据增强(旋转/翻转)后需要**重新从增强后的地图中计算**,因为 `rot90(k=1/3)` 会交换 `symmetryH` 和 `symmetryV`。其他三个标签在旋转/翻转下保持不变,可直接读取。
|
||||
|
||||
```python
|
||||
def _compute_symmetry(target_np: np.ndarray) -> tuple[int, int, int]:
|
||||
"""从 numpy 地图矩阵中直接计算三种对称性,O(H*W)"""
|
||||
H, W = target_np.shape
|
||||
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
|
||||
sym_v = bool(np.all(target_np == target_np[::-1, :]))
|
||||
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
|
||||
return int(sym_h), int(sym_v), int(sym_c)
|
||||
```
|
||||
|
||||
在 `__getitem__` 数据增强完成后,读取所有标签:
|
||||
|
||||
```python
|
||||
def __getitem__(self, idx):
|
||||
# ...(原有增强逻辑,target_np 已经过 rot90 / flip)
|
||||
|
||||
# 对称性:在增强后重新计算
|
||||
sym_h, sym_v, sym_c = _compute_symmetry(target_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
|
||||
|
||||
# 其余标签:增强不改变拓扑结构,直接读取
|
||||
item = self.data[idx]
|
||||
cond_room = item['roomCountLevel'] # 0/1/2
|
||||
cond_branch = item['branchLevel'] # 0/1/2
|
||||
cond_outer = item['outerWall'] # 0/1
|
||||
|
||||
# 封装为 tensor
|
||||
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
|
||||
return {
|
||||
"raw_map": ...,
|
||||
"masked_map": ...,
|
||||
"target_map": ...,
|
||||
"subset": ...,
|
||||
"struct_cond": struct_cond # [4],供模型 Embedding 查表
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第六步:模型侧条件注入(`ginka/maskGIT/model.py`)
|
||||
|
||||
`struct_cond` 的四个维度分别对应不同词表大小的 Embedding:
|
||||
|
||||
```python
|
||||
# Vocabulary sizes. Each table reserves its LAST index as the "null"
# (unconditional) placeholder used by condition dropout, so every size is the
# number of valid classes plus one.
# cond_sym takes values in [0, 7]; 7 (all three symmetries) is a valid class,
# so the table needs 9 rows to keep the null index (8) from colliding with it.
SYM_VOCAB = 9     # cond_sym: [0, 7] + 1 null
ROOM_VOCAB = 4    # cond_room: [0, 2] + 1 null (used for dropout)
BRANCH_VOCAB = 4  # cond_branch: [0, 2] + 1 null
OUTER_VOCAB = 3   # cond_outer: [0, 1] + 1 null
|
||||
|
||||
# 在 GinkaMaskGIT.__init__ 中
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
```
|
||||
|
||||
在 `forward` 中,将四个 Embedding 结果与 VQ-VAE 的 z 序列拼接,作为 Cross-Attention 的 memory:
|
||||
|
||||
```python
|
||||
def forward(self, map_tokens, z, struct_cond, dropout_struct=False):
|
||||
# z: [B, L, d_z]
|
||||
# struct_cond: [B, 4] — [sym, room, branch, outer]
|
||||
|
||||
B = z.size(0)
|
||||
|
||||
if dropout_struct or (self.training and torch.rand(1) < self.struct_dropout_prob):
|
||||
# 条件 dropout:全部替换为 null index(各词表最后一个 index)
|
||||
e_sym = self.sym_embed(torch.full((B,), SYM_VOCAB - 1, device=z.device))
|
||||
e_room = self.room_embed(torch.full((B,), ROOM_VOCAB - 1, device=z.device))
|
||||
e_branch = self.branch_embed(torch.full((B,), BRANCH_VOCAB - 1, device=z.device))
|
||||
e_outer = self.outer_embed(torch.full((B,), OUTER_VOCAB - 1, device=z.device))
|
||||
else:
|
||||
e_sym = self.sym_embed(struct_cond[:, 0]) # [B, d_z]
|
||||
e_room = self.room_embed(struct_cond[:, 1])
|
||||
e_branch = self.branch_embed(struct_cond[:, 2])
|
||||
e_outer = self.outer_embed(struct_cond[:, 3])
|
||||
|
||||
# 将四个结构标签嵌入拼接为序列,与 z 合并
|
||||
# 每个 e_* 形状为 [B, d_z],unsqueeze 后变为 [B, 1, d_z]
|
||||
struct_seq = torch.stack(
|
||||
[e_sym, e_room, e_branch, e_outer], dim=1
|
||||
) # [B, 4, d_z]
|
||||
|
||||
memory = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z]
|
||||
|
||||
# 后续 Cross-Attention 正常进行
|
||||
# query = map_token_embeddings,key/value = memory
|
||||
...
|
||||
```
|
||||
|
||||
**null index 规则**:各 Embedding 的最后一个 index 保留为"无条件"占位符,词表大小因此比有效类别数多 1。
|
||||
|
||||
---
|
||||
|
||||
## 边界情况与注意事项
|
||||
|
||||
### 关于两趟扫描的时机
|
||||
|
||||
两趟扫描在 `Dataset.__init__` 中完成,整个过程仅需遍历两次 Python 列表,耗时可忽略不计(数千条数据 < 1ms)。不建议延迟到 `__getitem__` 中逐条计算。
|
||||
|
||||
### 关于对称性在数据增强下的变化
|
||||
|
||||
| 增强操作 | `symmetryH` | `symmetryV` | `symmetryC` |
|
||||
| ----------------- | ----------- | ----------- | ----------- |
|
||||
| `fliplr` | 不变 | 不变 | 不变 |
|
||||
| `flipud` | 不变 | 不变 | 不变 |
|
||||
| `rot90(k=1 or 3)` | 与 V 交换 | 与 H 交换 | 不变 |
|
||||
| `rot90(k=2)` | 不变 | 不变 | 不变 |
|
||||
|
||||
因此在 `__getitem__` 中对增强后的地图**重新计算**对称性是最简洁正确的方案,无需记录增强历史。
|
||||
|
||||
### 关于极端分布
|
||||
|
||||
若训练集中某一标签的样本极度不均匀(如 90% 的地图无对称性),可以在条件 Dropout 中对不常见的条件值适当降低 dropout 概率,以确保模型充分学习该条件。初始阶段统一使用相同 dropout 概率即可,后续根据生成效果调整。
|
||||
|
||||
### 关于 `roomCount = 0` 和 `highDegBranchCount = 0` 的地图
|
||||
|
||||
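这类地图本应落入 Low(0)等级,但分箱规则是 `v < th1` 才判为 Low:如果训练集中约 1/3 以上的地图取值为 0,则 `th1 = 0`,此时没有任何样本满足 `v < th1`,Low 等级为空,取值为 0 的地图会被划入 Medium(若 `th2` 也为 0 则划入 High)。同理,若 `th1 == th2`,则 Medium 等级为空。可以在分箱前加一步检查:若 `th1 == th2`,则手动将 `th2 = th1 + 1` 以避免 Medium 等级为空;如果还需要保证 Low 等级非空,需要对 `th1` 等于最小值的情况另行处理(例如同样将 `th1` 上调 1 后再做上述检查)。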
|
||||
|
||||
```python
|
||||
th1 = sorted_vals[n // 3]
|
||||
th2 = sorted_vals[2 * n // 3]
|
||||
if th1 == th2:
|
||||
th2 = th1 + 1 # 防止 Medium 等级为空
|
||||
```
|
||||
@ -376,7 +376,7 @@ $$\mathcal{L} = \mathcal{L}_{CE}(\text{MaskGIT}) + \beta \cdot \mathcal{L}_{comm
|
||||
|
||||
## 待探索事项
|
||||
|
||||
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验)
|
||||
- 合适的 K、L 取值(建议从 K=16, L=2 开始实验),K=1, L=64 可能比较合适。
|
||||
- z dropout 的最优概率
|
||||
- 若后续 codebook collapse:引入 EMA 更新 + 重置机制
|
||||
- 若后续需要更细粒度控制:加入标量 cond(需对推理侧标量做随机扰动处理)
|
||||
|
||||
@ -232,6 +232,14 @@ class GinkaJointDataset(Dataset):
|
||||
}
|
||||
|
||||
|
||||
def _compute_symmetry(target_np: np.ndarray) -> tuple:
|
||||
"""从 numpy 地图矩阵中直接计算三种对称性,O(H*W)"""
|
||||
sym_h = bool(np.all(target_np == target_np[:, ::-1]))
|
||||
sym_v = bool(np.all(target_np == target_np[::-1, :]))
|
||||
sym_c = bool(np.all(target_np == target_np[::-1, ::-1]))
|
||||
return int(sym_h), int(sym_v), int(sym_c)
|
||||
|
||||
|
||||
class GinkaVQDataset(Dataset):
|
||||
"""
|
||||
用于 VQ-VAE + MaskGIT 联合训练的多子集数据集。
|
||||
@ -259,13 +267,17 @@ class GinkaVQDataset(Dataset):
|
||||
data_path: str,
|
||||
subset_weights: tuple = (0.5, 0.2, 0.2, 0.1),
|
||||
wall_mask_ratio: float = 0.3,
|
||||
room_thresholds: tuple = None,
|
||||
branch_thresholds: tuple = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
data_path: JSON 数据文件路径
|
||||
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
|
||||
wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限
|
||||
(每次从 [0, wall_mask_ratio] 均匀采样实际比例)
|
||||
data_path: JSON 数据文件路径
|
||||
subset_weights: 子集 (A, B, C, D) 的采样权重,自动归一化
|
||||
wall_mask_ratio: Subset C 中额外随机 mask 的 wall tile 比例上限
|
||||
(每次从 [0, wall_mask_ratio] 均匀采样实际比例)
|
||||
room_thresholds: (th1, th2) 房间数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||||
branch_thresholds: (th1, th2) 分支数量等频分箱阈值;为 None 时自动从当前数据计算(训练集)
|
||||
"""
|
||||
self.data = load_data(data_path)
|
||||
self.wall_mask_ratio = wall_mask_ratio
|
||||
@ -275,6 +287,35 @@ class GinkaVQDataset(Dataset):
|
||||
normalized = [x / total_w for x in subset_weights]
|
||||
self.subset_cumw = [sum(normalized[:i + 1]) for i in range(len(normalized))]
|
||||
|
||||
# ── 两趟扫描:计算等频分箱阈值 ──────────────────────────────
|
||||
room_counts = [item['roomCount'] for item in self.data]
|
||||
branch_counts = [item['highDegBranchCount'] for item in self.data]
|
||||
|
||||
if room_thresholds is None:
|
||||
n = len(room_counts)
|
||||
rs = sorted(room_counts)
|
||||
bs = sorted(branch_counts)
|
||||
th1_r, th2_r = rs[n // 3], rs[2 * n // 3]
|
||||
th1_b, th2_b = bs[n // 3], bs[2 * n // 3]
|
||||
# 防止 Medium 等级为空
|
||||
if th1_r == th2_r:
|
||||
th2_r = th1_r + 1
|
||||
if th1_b == th2_b:
|
||||
th2_b = th1_b + 1
|
||||
self.room_th = (th1_r, th2_r)
|
||||
self.branch_th = (th1_b, th2_b)
|
||||
else:
|
||||
self.room_th = room_thresholds
|
||||
self.branch_th = branch_thresholds
|
||||
|
||||
def to_level(v: int, th: tuple) -> int:
|
||||
return 0 if v < th[0] else (1 if v < th[1] else 2)
|
||||
|
||||
# 回填等级字段
|
||||
for item in self.data:
|
||||
item['roomCountLevel'] = to_level(item['roomCount'], self.room_th)
|
||||
item['branchLevel'] = to_level(item['highDegBranchCount'], self.branch_th)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
@ -407,11 +448,23 @@ class GinkaVQDataset(Dataset):
|
||||
masked_np = self._apply_subset(raw_np, subset) # [H*W]
|
||||
raw_flat = raw_np.reshape(-1) # [H*W]
|
||||
|
||||
# 对称性:在增强后重新计算
|
||||
sym_h, sym_v, sym_c = _compute_symmetry(raw_np)
|
||||
cond_sym = sym_h * 4 + sym_v * 2 + sym_c # [0, 7]
|
||||
|
||||
# 其余结构标签:增强不改变拓扑结构,直接读取
|
||||
cond_room = item['roomCountLevel'] # 0/1/2
|
||||
cond_branch = item['branchLevel'] # 0/1/2
|
||||
cond_outer = item['outerWall'] # 0/1
|
||||
|
||||
struct_cond = torch.LongTensor([cond_sym, cond_room, cond_branch, cond_outer])
|
||||
|
||||
return {
|
||||
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
|
||||
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
|
||||
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
|
||||
"subset": subset, # 调试/统计用
|
||||
"raw_map": torch.LongTensor(raw_flat), # VQ-VAE 编码器输入
|
||||
"masked_map": torch.LongTensor(masked_np), # MaskGIT 输入
|
||||
"target_map": torch.LongTensor(raw_flat.copy()), # CE loss ground truth
|
||||
"subset": subset, # 调试/统计用
|
||||
"struct_cond": struct_cond, # [4],供模型 Embedding 查表
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -4,6 +4,12 @@ import torch.nn as nn
|
||||
from ..utils import print_memory
|
||||
from .maskGIT import Transformer
|
||||
|
||||
# 结构标签词表大小(最后一个索引为无条件占位符 null)
|
||||
SYM_VOCAB = 9       # symmetry code sym_h*4 + sym_v*2 + sym_c spans 0-7 (7 = fully symmetric map); 8 = null
|
||||
ROOM_VOCAB = 4 # roomCountLevel 0-2,3 = null
|
||||
BRANCH_VOCAB = 4 # branchLevel 0-2,3 = null
|
||||
OUTER_VOCAB = 3 # outerWall 0-1,2 = null
|
||||
|
||||
|
||||
class GinkaMaskGIT(nn.Module):
|
||||
"""
|
||||
@ -26,20 +32,24 @@ class GinkaMaskGIT(nn.Module):
|
||||
num_layers: int = 4,
|
||||
map_size: int = 13 * 13,
|
||||
z_dropout: float = 0.1,
|
||||
struct_dropout: float = 0.15,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
num_classes: tile 类别数(含 MASK token=15)
|
||||
d_model: Transformer 内部维度
|
||||
d_z: VQ-VAE 码字嵌入维度,需与 GinkaVQVAE.d_z 一致
|
||||
dim_ff: 前馈网络隐层维度
|
||||
nhead: 注意力头数
|
||||
num_layers: Transformer 层数
|
||||
map_size: 地图 token 总数(H * W)
|
||||
z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性)
|
||||
num_classes: tile 类别数(含 MASK token=15)
|
||||
d_model: Transformer 内部维度
|
||||
d_z: VQ-VAE 码字嵌入维度,需与 GinkaVQVAE.d_z 一致
|
||||
dim_ff: 前馈网络隐层维度
|
||||
nhead: 注意力头数
|
||||
num_layers: Transformer 层数
|
||||
map_size: 地图 token 总数(H * W)
|
||||
z_dropout: 训练时随机替换 z 为随机码字的概率(提升鲁棒性)
|
||||
struct_dropout: 训练时以此概率将结构标签替换为 null(无条件占位),
|
||||
实现 classifier-free guidance 兼容训练
|
||||
"""
|
||||
super().__init__()
|
||||
self.z_dropout = z_dropout
|
||||
self.struct_dropout_prob = struct_dropout
|
||||
|
||||
# Tile 嵌入 + 位置编码
|
||||
self.tile_embedding = nn.Embedding(num_classes, d_model)
|
||||
@ -51,6 +61,12 @@ class GinkaMaskGIT(nn.Module):
|
||||
nn.LayerNorm(d_model),
|
||||
)
|
||||
|
||||
# 结构标签嵌入(编码到 d_z 维度,与 z 拼接后统一投影到 d_model)
|
||||
self.sym_embed = nn.Embedding(SYM_VOCAB, d_z)
|
||||
self.room_embed = nn.Embedding(ROOM_VOCAB, d_z)
|
||||
self.branch_embed = nn.Embedding(BRANCH_VOCAB, d_z)
|
||||
self.outer_embed = nn.Embedding(OUTER_VOCAB, d_z)
|
||||
|
||||
# Transformer:encoder 做 map token 自注意力,decoder 做与 z 的 cross-attention
|
||||
self.transformer = Transformer(
|
||||
d_model=d_model, dim_ff=dim_ff, nhead=nhead, num_layers=num_layers
|
||||
@ -58,33 +74,75 @@ class GinkaMaskGIT(nn.Module):
|
||||
|
||||
self.output_fc = nn.Linear(d_model, num_classes)
|
||||
|
||||
def forward(self, map: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
|
||||
def forward(
|
||||
self,
|
||||
map: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
struct_cond: torch.Tensor | None = None,
|
||||
dropout_struct: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15)
|
||||
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
|
||||
map: [B, H*W] 掩码后的地图 token 序列(MASK token = 15)
|
||||
z: [B, L, d_z] VQ-VAE 量化后的离散隐变量
|
||||
struct_cond: [B, 4] 结构标签 LongTensor,顺序为
|
||||
[cond_sym, cond_room, cond_branch, cond_outer];
|
||||
为 None 时等价于全 null(无条件模式)
|
||||
dropout_struct: bool 强制将所有结构标签替换为 null(推理时无条件生成)
|
||||
|
||||
Returns:
|
||||
logits: [B, H*W, num_classes]
|
||||
"""
|
||||
B = z.shape[0]
|
||||
|
||||
# z dropout:训练时以一定概率将 z 替换为随机均匀噪声,
|
||||
# 模拟推理时随机采样 z 的分布,避免模型过拟合于精确的 z 语义
|
||||
if self.training and self.z_dropout > 0:
|
||||
mask = torch.rand(z.shape[0], 1, 1, device=z.device) < self.z_dropout
|
||||
mask = torch.rand(B, 1, 1, device=z.device) < self.z_dropout
|
||||
rand_z = torch.randn_like(z)
|
||||
z = torch.where(mask, rand_z, z)
|
||||
|
||||
# 投影 z 到 d_model 维度
|
||||
z_mem = self.z_proj(z) # [B, L, d_model]
|
||||
# 结构标签嵌入
|
||||
# struct_cond 为 None 或 dropout_struct=True 时,全部使用 null 索引
|
||||
if struct_cond is None or dropout_struct:
|
||||
sym_idx = torch.full((B,), SYM_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
room_idx = torch.full((B,), ROOM_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
branch_idx = torch.full((B,), BRANCH_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
outer_idx = torch.full((B,), OUTER_VOCAB - 1, dtype=torch.long, device=z.device)
|
||||
else:
|
||||
sc = struct_cond.to(z.device)
|
||||
sym_idx, room_idx, branch_idx, outer_idx = sc[:, 0], sc[:, 1], sc[:, 2], sc[:, 3]
|
||||
|
||||
# 训练时对各标签独立做 struct dropout
|
||||
if self.training and self.struct_dropout_prob > 0:
|
||||
def _drop(idx, null_val):
|
||||
drop_mask = torch.rand(B, device=z.device) < self.struct_dropout_prob
|
||||
return torch.where(drop_mask, torch.full_like(idx, null_val), idx)
|
||||
sym_idx = _drop(sym_idx, SYM_VOCAB - 1)
|
||||
room_idx = _drop(room_idx, ROOM_VOCAB - 1)
|
||||
branch_idx = _drop(branch_idx, BRANCH_VOCAB - 1)
|
||||
outer_idx = _drop(outer_idx, OUTER_VOCAB - 1)
|
||||
|
||||
# 嵌入结构标签到 d_z 维度,拼接到 z 序列末尾
|
||||
e_sym = self.sym_embed(sym_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_room = self.room_embed(room_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_branch = self.branch_embed(branch_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
e_outer = self.outer_embed(outer_idx).unsqueeze(1) # [B, 1, d_z]
|
||||
|
||||
struct_seq = torch.cat([e_sym, e_room, e_branch, e_outer], dim=1) # [B, 4, d_z]
|
||||
z_ext = torch.cat([z, struct_seq], dim=1) # [B, L+4, d_z]
|
||||
|
||||
# 统一投影到 d_model 维度
|
||||
z_mem = self.z_proj(z_ext) # [B, L+4, d_model]
|
||||
|
||||
# tile embedding + 位置编码
|
||||
x = self.tile_embedding(map) # [B, H*W, d_model]
|
||||
x = x + self.pos_embedding # [B, H*W, d_model]
|
||||
x = self.tile_embedding(map) # [B, H*W, d_model]
|
||||
x = x + self.pos_embedding # [B, H*W, d_model]
|
||||
|
||||
# Transformer:encoder 做 map 自注意力,decoder cross-attend z
|
||||
# Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct
|
||||
x = self.transformer(x, memory=z_mem) # [B, H*W, d_model]
|
||||
|
||||
logits = self.output_fc(x) # [B, H*W, num_classes]
|
||||
logits = self.output_fc(x) # [B, H*W, num_classes]
|
||||
return logits
|
||||
|
||||
|
||||
@ -107,11 +165,19 @@ if __name__ == "__main__":
|
||||
n = sum(p.numel() for p in module.parameters())
|
||||
print(f" {name}: {n:,}")
|
||||
|
||||
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
|
||||
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
|
||||
map_input = torch.randint(0, 16, (4, 13 * 13)).to(device) # [B=4, 169]
|
||||
z_input = torch.randn(4, 2, 64).to(device) # [B=4, L=2, d_z=64]
|
||||
struct_input = torch.tensor([[3, 1, 0, 1],
|
||||
[0, 2, 1, 0],
|
||||
[7, 3, 3, 2],
|
||||
[1, 0, 2, 1]], dtype=torch.long).to(device) # [B=4, 4]
|
||||
|
||||
model.train()
|
||||
logits = model(map_input, z_input)
|
||||
logits = model(map_input, z_input, struct_cond=struct_input)
|
||||
print(f"\nlogits shape: {logits.shape}") # [4, 169, 16]
|
||||
|
||||
# 无条件模式测试
|
||||
logits_uncond = model(map_input, z_input, struct_cond=None)
|
||||
print(f"logits_uncond shape: {logits_uncond.shape}") # [4, 169, 16]
|
||||
|
||||
print_memory(device, "前向传播后")
|
||||
|
||||
@ -59,7 +59,8 @@ MG_D_MODEL = 192
|
||||
MG_NHEAD = 8
|
||||
MG_LAYERS = 4
|
||||
MG_DIM_FF = 512
|
||||
MG_Z_DROPOUT= 0.15 # 训练时以此概率把 z 替换为随机噪声
|
||||
MG_Z_DROPOUT = 0.15 # 训练时以此概率把 z 替换为随机噪声
|
||||
MG_STRUCT_DROPOUT= 0.15 # 训练时以此概率将结构标签替换为 null(无条件占位)
|
||||
|
||||
# 验证时对每条样本额外采样的 z 数量(0 = 只用真实 z)
|
||||
N_Z_SAMPLES = 3
|
||||
@ -106,6 +107,7 @@ def maskgit_generate(
|
||||
z: torch.Tensor,
|
||||
steps: int = GENERATE_STEP,
|
||||
init_map: torch.Tensor = None,
|
||||
struct_cond: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
迭代生成地图(cosine schedule unmasking)。
|
||||
@ -133,7 +135,7 @@ def maskgit_generate(
|
||||
if not generatable.any():
|
||||
break
|
||||
|
||||
logits = model_mg(map_seq, z) # [B, S, C]
|
||||
logits = model_mg(map_seq, z, struct_cond=struct_cond) # [B, S, C]
|
||||
probs = F.softmax(logits, dim=-1)
|
||||
dist = torch.distributions.Categorical(probs)
|
||||
sampled = dist.sample() # [B, S]
|
||||
@ -279,7 +281,8 @@ def validate(
|
||||
B = raw_map.shape[0]
|
||||
|
||||
z_q, _, vq_loss, _, _ = model_vq(raw_map)
|
||||
logits = model_mg(masked_map, z_q)
|
||||
struct_cond_b = batch["struct_cond"].to(device) # [B, 4]
|
||||
logits = model_mg(masked_map, z_q, struct_cond=struct_cond_b)
|
||||
mask = (masked_map == MASK_TOKEN)
|
||||
|
||||
ce_loss = F.cross_entropy(
|
||||
@ -294,9 +297,10 @@ def validate(
|
||||
s = subsets[i]
|
||||
if captured[s] is None:
|
||||
captured[s] = {
|
||||
"raw": raw_map[i:i+1].clone(),
|
||||
"masked": masked_map[i:i+1].clone(),
|
||||
"z_q": z_q[i:i+1].clone(),
|
||||
"raw": raw_map[i:i+1].clone(),
|
||||
"masked": masked_map[i:i+1].clone(),
|
||||
"z_q": z_q[i:i+1].clone(),
|
||||
"struct_cond": struct_cond_b[i:i+1].clone(),
|
||||
}
|
||||
|
||||
if all(v is not None for v in captured.values()):
|
||||
@ -307,24 +311,24 @@ def validate(
|
||||
imgs = []
|
||||
for i in range(n):
|
||||
z_r = model_vq.sample(1, device)
|
||||
gen = maskgit_generate(model_mg, z_r, init_map=cond_map)
|
||||
gen = maskgit_generate(model_mg, z_r, init_map=cond_map) # struct_cond=None 无条件
|
||||
imgs.append(label_image(make_map_image(gen[0], tile_dict), f"z_rand_{i + 1}"))
|
||||
return imgs
|
||||
|
||||
# ── 场景1:标准掩码补全(子集 A)─────────────────────────────────────────
|
||||
if captured['A'] is not None:
|
||||
cap = captured['A']
|
||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
||||
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||
|
||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "masked input")
|
||||
|
||||
# 单步 argmax 预测(观察模型对掩码位置的瞬时判断)
|
||||
pred = model_mg(cond, z_q).argmax(dim=-1)[0]
|
||||
pred = model_mg(cond, z_q, struct_cond=sc).argmax(dim=-1)[0]
|
||||
pred_img = label_image(make_map_image(pred, tile_dict), "z_real pred")
|
||||
|
||||
# 迭代生成(从掩码输入出发,真实 z)
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, pred_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
@ -333,11 +337,11 @@ def validate(
|
||||
# ── 场景2:墙壁辅助生成(子集 B)─────────────────────────────────────────
|
||||
if captured['B'] is not None:
|
||||
cap = captured['B']
|
||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
||||
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||
|
||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall-only input")
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
@ -346,11 +350,11 @@ def validate(
|
||||
# ── 场景3:稀疏墙壁条件生成(子集 C)────────────────────────────────────
|
||||
if captured['C'] is not None:
|
||||
cap = captured['C']
|
||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
||||
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||
|
||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "sparse wall input")
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
@ -359,11 +363,11 @@ def validate(
|
||||
# ── 场景4:墙壁+入口条件生成(子集 D)───────────────────────────────────
|
||||
if captured['D'] is not None:
|
||||
cap = captured['D']
|
||||
raw, cond, z_q = cap['raw'], cap['masked'], cap['z_q']
|
||||
raw, cond, z_q, sc = cap['raw'], cap['masked'], cap['z_q'], cap['struct_cond']
|
||||
|
||||
real_img = label_image(make_map_image(raw[0], tile_dict), "ground truth")
|
||||
cond_img = label_image(make_map_image(cond[0], tile_dict), "wall+entrance input")
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond)
|
||||
gen_real = maskgit_generate(model_mg, z_q, init_map=cond, struct_cond=sc)
|
||||
gen_r_img = label_image(make_map_image(gen_real[0], tile_dict), "z_real gen")
|
||||
|
||||
row = [real_img, cond_img, gen_r_img] + _rand_gens(cond, N_Z_SAMPLES)
|
||||
@ -403,6 +407,7 @@ def train():
|
||||
num_layers=MG_LAYERS,
|
||||
map_size=MAP_SIZE,
|
||||
z_dropout=MG_Z_DROPOUT,
|
||||
struct_dropout=MG_STRUCT_DROPOUT,
|
||||
).to(device)
|
||||
|
||||
vq_params = sum(p.numel() for p in model_vq.parameters())
|
||||
@ -419,6 +424,8 @@ def train():
|
||||
dataset_val = GinkaVQDataset(
|
||||
args.validate,
|
||||
subset_weights=SUBSET_WEIGHTS,
|
||||
room_thresholds=dataset_train.room_th,
|
||||
branch_thresholds=dataset_train.branch_th,
|
||||
)
|
||||
dataloader_train = DataLoader(
|
||||
dataset_train, batch_size=BATCH_SIZE, shuffle=True,
|
||||
@ -481,8 +488,9 @@ def train():
|
||||
# 1. VQ-VAE 编码真实地图 → z_q
|
||||
z_q, _, vq_loss, commit_loss, entropy_loss = model_vq(raw_map) # z_q: [B, L, d_z]
|
||||
|
||||
# 2. MaskGIT 以掩码地图 + z 预测原始 tile
|
||||
logits = model_mg(masked_map, z_q) # [B, 169, C]
|
||||
# 2. MaskGIT 以掩码地图 + z + 结构标签预测原始 tile
|
||||
struct_cond = batch["struct_cond"].to(device) # [B, 4]
|
||||
logits = model_mg(masked_map, z_q, struct_cond=struct_cond) # [B, 169, C]
|
||||
|
||||
# 3. 只对被 mask 的位置计算 CE loss
|
||||
mask = (masked_map == MASK_TOKEN) # [B, 169] bool
|
||||
|
||||
Loading…
Reference in New Issue
Block a user