ginka-generator/docs/2d-spatial-awareness-design.md

355 lines
13 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Stage1 二维空间感知改进设计文档
## 问题诊断
### 核心现象
第一阶段floor/wall 骨架生成)质量显著劣于第二、三阶段。生成结果常见表现为:
- 墙壁分布破碎,缺乏连通性;
- 房间边界不完整,出现孤立墙块;
- 走廊未能形成闭合通道;
- 生成结果的空间拓扑结构与训练集分布偏差较大。
第二、三阶段(门/怪物/资源放置)问题相对较少,因为这些阶段的任务是"在已有结构上填充稀疏元素",对空间结构的整体一致性要求较低。
### 根本原因:一维位置编码与二维网格的结构失配
#### 现有架构
`GinkaMaskGIT` 当前使用以下位置编码:
```python
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
```
这是一个纯一维可学习位置嵌入,将 13×13=169 个格子视为一条线性序列。`GinkaVQVAE` 中同样如此。
`maskGIT.py` 中的 `Transformer` 使用标准 `nn.TransformerEncoder` + `nn.TransformerDecoder`,注意力机制中没有任何二维空间偏置。
#### 一维展平带来的结构失配
将 13×13 地图按行展平后,位置关系如下:
```
位置(0,12) → token 12
位置(1, 0) → token 13
```
这两个 token 在一维序列中相邻(距离 1但在二维地图上相距 12 列(横跨整行)。而真正的二维邻居关系:
```
位置(0, 0) 和 位置(1, 0) 的二维距离 = 1 格
对应 token 0 和 token 13一维距离 = 13
```
一维位置嵌入告诉模型"token 0 和 token 13 相距较远",但实际上它们是相邻的竖向邻格。注意力机制的相对偏置完全依赖位置嵌入的初始化,无法从一维嵌入中自动推断二维邻接关系。
#### 为何 Stage1 特别敏感
Stage1 负责生成 floor/wall 骨架,这是整个地图中**空间结构约束最强**的层次:
- 墙壁需要形成封闭或半封闭的房间边界;
- 走廊需要是连通的、宽度一致的通道;
- 整体拓扑(房间数、对称性、外围走廊)需要全局一致。
上述约束全部是**二维局部连通性约束**:一个墙壁格子是否合理,取决于它的上下左右四个邻格,而非它前后若干 token。一维位置编码使模型必须从数据中隐式学习这种行列边界代价高昂且泛化差。
相比之下Stage2/3 的任务(在走廊或房间内散布门/怪物/资源)对全局结构的空间一致性要求较低,位置编码的精确性影响较小。
---
## 改进方案
### 方案 A二维因式分解位置嵌入推荐首选
#### 思路
将当前单一的一维位置嵌入替换为行嵌入与列嵌入的加和:
```
pos_embed[i, j] = row_embed[i] + col_embed[j]
```
这样,同一行的所有格子共享相同的行嵌入,同一列的所有格子共享相同的列嵌入。模型可以直接从嵌入中感知行列身份,而无需从一维序号中隐式推断。
#### 具体实现
```python
# 替换原有的 pos_embedding
self.row_embedding = nn.Parameter(torch.randn(1, MAP_H, d_model) * 0.02)
self.col_embedding = nn.Parameter(torch.randn(1, MAP_W, d_model) * 0.02)
```
前向传播中:
```python
# map: [B, H*W]
row_idx = torch.arange(MAP_H, device=map.device).repeat_interleave(MAP_W)
col_idx = torch.arange(MAP_W, device=map.device).repeat(MAP_H)
pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx]
x = self.tile_embedding(map) + pos.unsqueeze(0)
```
也可以预计算展开后的索引并缓存。
#### 特点
- 参数量变化:从 `169 × d_model` 变为 `13 × d_model + 13 × d_model = 26 × d_model`,显著减少,且参数共享有助于泛化;
- 无需修改注意力机制,改动最小;
- 直接赋予模型行列语义,改进效果立竿见影。
---
### 方案 B二维相对位置偏置推荐次选与 A 叠加)
#### 思路
在注意力计算中,对每对 query-key 的打分加入一个可学习偏置,偏置由两个 token 的**相对行列偏移量**决定:
```
score(i, j) = (q_i · k_j) / sqrt(d) + B[Δrow, Δcol]
```
其中 `Δrow = row(i) - row(j)``Δcol = col(i) - col(j)`。偏置表 B 的形状为 `(2H-1, 2W-1)` = `(25, 25)`,每个注意力头各一张。
这种方式的核心优势:注意力打分天然理解"相邻格子应更强相关",模型无需从位置嵌入中隐式学习距离感。
#### 具体实现
**步骤一:预计算相对位置索引表**
对于 13×13 的地图,预计算每对 token (i, j) 的相对位置,将二维偏移量映射为一维索引,供后续从偏置表中 gather
```python
def build_relative_position_index(H: int, W: int) -> torch.Tensor:
coords_h = torch.arange(H)
coords_w = torch.arange(W)
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij')) # [2, H, W]
coords_flat = coords.flatten(1) # [2, H*W]
rel = coords_flat[:, :, None] - coords_flat[:, None, :] # [2, H*W, H*W]
rel[0] += H - 1
rel[1] += W - 1
rel_index = rel[0] * (2 * W - 1) + rel[1] # [H*W, H*W]
return rel_index
```
**步骤二:在模型中注册偏置表**
需要自定义 `SelfAttentionWithRPB`,替换 `nn.TransformerEncoderLayer` 中的注意力:
```python
class SelfAttentionWithRPB(nn.Module):
def __init__(self, d_model: int, nhead: int, H: int, W: int):
super().__init__()
self.nhead = nhead
self.d_head = d_model // nhead
self.scale = self.d_head ** -0.5
self.qkv = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
# 偏置表:(2H-1) * (2W-1) 个可能的相对位置,每个头各一组
self.rel_bias_table = nn.Parameter(
torch.zeros(nhead, (2 * H - 1) * (2 * W - 1))
)
rel_index = build_relative_position_index(H, W) # [H*W, H*W]
self.register_buffer('rel_index', rel_index.flatten()) # [H*W * H*W]
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.nhead, self.d_head)
q, k, v = qkv.unbind(2) # 各 [B, N, nhead, d_head]
q = q.permute(0, 2, 1, 3) # [B, nhead, N, d_head]
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, nhead, N, N]
# 从偏置表中取出对应偏置reshape 成 [nhead, N, N]
bias = self.rel_bias_table[:, self.rel_index].reshape(self.nhead, N, N)
attn = attn + bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
out = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, C)
return self.out_proj(out)
```
**步骤三:替换 TransformerEncoderLayer 中的注意力**
由于 PyTorch 标准 `TransformerEncoderLayer` 不支持直接替换注意力实现,需要手写包含 RPB 的编码器层:
```python
class RPBEncoderLayer(nn.Module):
def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = SelfAttentionWithRPB(d_model, nhead, H, W)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_ff),
nn.GELU(),
nn.Linear(dim_ff, d_model)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
```
#### 参数量评估
对于 Stage1 MaskGITd_model=192, nhead=8, num_layers=6H=W=13
- 偏置表:每层 `8 × 25 × 25 = 5000` 参数
- 共 6 层30,000 参数
- 占总参数量的比例极低,但对 attention 的几何感知能力提升显著。
#### 特点
- 与方案 A 互补,可叠加使用;
- 理论上最接近"正确"的二维注意力感应偏置;
- 实现较复杂,需要手写注意力层;
- 参数量增加极少。
---
### 方案 C轴向注意力Axial Attention
#### 思路
将每一个标准自注意力层替换为两个顺序执行的注意力:
1. **行轴注意力**每一行内的格子互相注意13 个 group每组 13 个 token
2. **列轴注意力**每一列内的格子互相注意13 个 group每组 13 个 token。
两种注意力交替叠加:`Row → Col → Row → Col → ...`
```
标准自注意力O((H*W)²) = O(169²) = 28,561
轴向注意力: O(H * W²) + O(W * H²) = O(H*W*(H+W)) = O(169*26) ≈ 4,394
```
复杂度大幅下降,且强制模型以行/列为单位建立空间关联。
#### 具体实现
```python
class AxialAttention(nn.Module):
def __init__(self, d_model: int, nhead: int, H: int, W: int):
super().__init__()
self.H = H
self.W = W
self.row_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.col_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.norm_row = nn.LayerNorm(d_model)
self.norm_col = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
H, W = self.H, self.W
x = x.reshape(B, H, W, C)
# 行注意力:每行的 W 个 token 互相注意
x_row = x.reshape(B * H, W, C)
x_row, _ = self.row_attn(x_row, x_row, x_row)
x = x + x_row.reshape(B, H, W, C)
x = self.norm_row(x)
# 列注意力:每列的 H 个 token 互相注意
x_col = x.permute(0, 2, 1, 3).reshape(B * W, H, C)
x_col, _ = self.col_attn(x_col, x_col, x_col)
x = x + x_col.reshape(B, W, H, C).permute(0, 2, 1, 3)
x = self.norm_col(x)
return x.reshape(B, N, C)
class AxialEncoderLayer(nn.Module):
def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int):
super().__init__()
self.axial_attn = AxialAttention(d_model, nhead, H, W)
self.norm_ff = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_ff),
nn.GELU(),
nn.Linear(dim_ff, d_model)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.axial_attn(x)
x = x + self.ffn(self.norm_ff(x))
return x
```
#### 特点
- 最强的二维结构归纳偏置:明确区分行内关系和列内关系;
- 改造最彻底,需要替换整个 Transformer 编码器;
- 计算量低于标准全局注意力;
- 不适合 cross-attention 部分decoder 中 map token cross-attend z 的部分保持不变)。
---
### 方案 D双轴并行输入Dual-Axis Input
#### 思路
同一张地图展平两次,分别按行优先和列优先展平,两份序列并行送入各自的嵌入+编码器,最终在 token 维度相加合并:
```
地图 [H, W]
├─ 行优先展平 → [H*W] → Embedding + 行位置编码 → 编码器 A → [H*W, d_model]
└─ 列优先展平 → [W*H] → Embedding + 列位置编码 → 编码器 B → [W*H, d_model]
↓ 重排到行优先顺序
相加合并 → [H*W, d_model] → Decoder + z → logits
```
两个编码器可以共享权重(只有位置编码不同),以减少参数量。
#### 特点
- 无需修改注意力机制,复用现有 Transformer
- 计算量加倍(两次编码),但可通过共享权重缓解;
- 同时为模型提供横向和纵向的序列上下文;
- 参数复用程度较高(共享编码器权重时)。
---
## 方案对比
| 方案 | 改动范围 | 额外参数量 | 二维感知能力 | 实现复杂度 |
| ------------------- | ------------ | ---------------------------------- | ------------ | ---------- |
| A二维因式位置嵌入 | 位置嵌入层 | 减少(共享) | 中等 | 低 |
| B二维相对位置偏置 | 注意力层 | 极少(~30k | 强 | 中等 |
| C轴向注意力 | 整个编码器 | 基本不变 | 最强 | 高 |
| D双轴并行输入 | 输入与编码器 | 按编码器大小翻倍(共享权重则不变) | 中等 | 中等 |
---
## 推荐实施策略
### 第一步:替换位置嵌入(方案 A
优先实施方案 A这是改动最小、风险最低的基础改进。仅需修改 `GinkaMaskGIT``GinkaVQVAE` 中的 `pos_embedding` 初始化与使用方式。
修改范围:
- `ginka/maskGIT/model.py`:替换 `pos_embedding`,调整 `forward` 中的位置嵌入加法;
- `ginka/vqvae/model.py`:同上。
### 第二步:叠加二维相对位置偏置(方案 B
在方案 A 的基础上,为 `GinkaMaskGIT` 的编码器部分叠加 RPB只需新增一个自定义 `RPBEncoderLayer`,替换 `maskGIT.py``Transformer.encoder` 的层类型。
VQ-VAE 编码器的 RPB 改造优先级较低VQ-VAE 负责全图压缩,对局部连通性感知需求低于 MaskGIT
### 可选第三步:轴向注意力替换(方案 C
若前两步改进后 Stage1 质量仍不满足要求,可进一步将 `GinkaMaskGIT` 的编码器改为轴向注意力。由于 decodercross-attention 部分)不涉及空间 token无需改动。
### 其他注意事项
- 上述改动仅针对 **Stage1 MaskGIT**。Stage2/3 可以同步修改,也可以保持原结构,视实际效果决定;
- 改动后需重置训练位置嵌入的结构变化会使旧检查点不兼容shape 不匹配);
- 验证时重点观察生成地图中墙壁的连通性、房间闭合度和整体拓扑是否改善。