mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 14:31:11 +08:00
355 lines
13 KiB
Markdown
355 lines
13 KiB
Markdown
# Stage1 二维空间感知改进设计文档
|
||
|
||
## 问题诊断
|
||
|
||
### 核心现象
|
||
|
||
第一阶段(floor/wall 骨架生成)质量显著劣于第二、三阶段。生成结果常见表现为:
|
||
|
||
- 墙壁分布破碎,缺乏连通性;
|
||
- 房间边界不完整,出现孤立墙块;
|
||
- 走廊未能形成闭合通道;
|
||
- 生成结果的空间拓扑结构与训练集分布偏差较大。
|
||
|
||
第二、三阶段(门/怪物/资源放置)问题相对较少,因为这些阶段的任务是"在已有结构上填充稀疏元素",对空间结构的整体一致性要求较低。
|
||
|
||
### 根本原因:一维位置编码与二维网格的结构失配
|
||
|
||
#### 现有架构
|
||
|
||
`GinkaMaskGIT` 当前使用以下位置编码:
|
||
|
||
```python
|
||
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
|
||
```
|
||
|
||
这是一个纯一维可学习位置嵌入,将 13×13=169 个格子视为一条线性序列。`GinkaVQVAE` 中同样如此。
|
||
|
||
`maskGIT.py` 中的 `Transformer` 使用标准 `nn.TransformerEncoder` + `nn.TransformerDecoder`,注意力机制中没有任何二维空间偏置。
|
||
|
||
#### 一维展平带来的结构失配
|
||
|
||
将 13×13 地图按行展平后,位置关系如下:
|
||
|
||
```
|
||
位置(0,12) → token 12
|
||
位置(1, 0) → token 13
|
||
```
|
||
|
||
这两个 token 在一维序列中相邻(距离 1),但在二维地图上相距 12 列(横跨整行)。而真正的二维邻居关系:
|
||
|
||
```
|
||
位置(0, 0) 和 位置(1, 0) 的二维距离 = 1 格
|
||
对应 token 0 和 token 13,一维距离 = 13
|
||
```
|
||
|
||
一维位置嵌入告诉模型"token 0 和 token 13 相距较远",但实际上它们是相邻的竖向邻格。注意力机制的相对偏置完全依赖位置嵌入的初始化,无法从一维嵌入中自动推断二维邻接关系。
|
||
|
||
#### 为何 Stage1 特别敏感
|
||
|
||
Stage1 负责生成 floor/wall 骨架,这是整个地图中**空间结构约束最强**的层次:
|
||
|
||
- 墙壁需要形成封闭或半封闭的房间边界;
|
||
- 走廊需要是连通的、宽度一致的通道;
|
||
- 整体拓扑(房间数、对称性、外围走廊)需要全局一致。
|
||
|
||
上述约束全部是**二维局部连通性约束**:一个墙壁格子是否合理,取决于它的上下左右四个邻格,而非它前后若干 token。一维位置编码使模型必须从数据中隐式学习这种行列边界,代价高昂且泛化差。
|
||
|
||
相比之下,Stage2/3 的任务(在走廊或房间内散布门/怪物/资源)对全局结构的空间一致性要求较低,位置编码的精确性影响较小。
|
||
|
||
---
|
||
|
||
## 改进方案
|
||
|
||
### 方案 A:二维因式分解位置嵌入(推荐首选)
|
||
|
||
#### 思路
|
||
|
||
将当前单一的一维位置嵌入替换为行嵌入与列嵌入的加和:
|
||
|
||
```
|
||
pos_embed[i, j] = row_embed[i] + col_embed[j]
|
||
```
|
||
|
||
这样,同一行的所有格子共享相同的行嵌入,同一列的所有格子共享相同的列嵌入。模型可以直接从嵌入中感知行列身份,而无需从一维序号中隐式推断。
|
||
|
||
#### 具体实现
|
||
|
||
```python
|
||
# 替换原有的 pos_embedding
|
||
self.row_embedding = nn.Parameter(torch.randn(1, MAP_H, d_model) * 0.02)
|
||
self.col_embedding = nn.Parameter(torch.randn(1, MAP_W, d_model) * 0.02)
|
||
```
|
||
|
||
前向传播中:
|
||
|
||
```python
|
||
# map: [B, H*W]
|
||
row_idx = torch.arange(MAP_H, device=map.device).repeat_interleave(MAP_W)
|
||
col_idx = torch.arange(MAP_W, device=map.device).repeat(MAP_H)
|
||
pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx]
|
||
x = self.tile_embedding(map) + pos.unsqueeze(0)
|
||
```
|
||
|
||
也可以预计算展开后的索引并缓存。
|
||
|
||
#### 特点
|
||
|
||
- 参数量变化:从 `169 × d_model` 变为 `13 × d_model + 13 × d_model = 26 × d_model`,显著减少,且参数共享有助于泛化;
|
||
- 无需修改注意力机制,改动最小;
|
||
- 直接赋予模型行列语义,改进效果立竿见影。
|
||
|
||
---
|
||
|
||
### 方案 B:二维相对位置偏置(推荐次选,与 A 叠加)
|
||
|
||
#### 思路
|
||
|
||
在注意力计算中,对每对 query-key 的打分加入一个可学习偏置,偏置由两个 token 的**相对行列偏移量**决定:
|
||
|
||
```
|
||
score(i, j) = (q_i · k_j) / sqrt(d) + B[Δrow, Δcol]
|
||
```
|
||
|
||
其中 `Δrow = row(i) - row(j)`,`Δcol = col(i) - col(j)`。偏置表 B 的形状为 `(2H-1, 2W-1)` = `(25, 25)`,每个注意力头各一张。
|
||
|
||
这种方式的核心优势:注意力打分天然理解"相邻格子应更强相关",模型无需从位置嵌入中隐式学习距离感。
|
||
|
||
#### 具体实现
|
||
|
||
**步骤一:预计算相对位置索引表**
|
||
|
||
对于 13×13 的地图,预计算每对 token (i, j) 的相对位置,将二维偏移量映射为一维索引,供后续从偏置表中 gather:
|
||
|
||
```python
|
||
def build_relative_position_index(H: int, W: int) -> torch.Tensor:
|
||
coords_h = torch.arange(H)
|
||
coords_w = torch.arange(W)
|
||
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij')) # [2, H, W]
|
||
coords_flat = coords.flatten(1) # [2, H*W]
|
||
rel = coords_flat[:, :, None] - coords_flat[:, None, :] # [2, H*W, H*W]
|
||
rel[0] += H - 1
|
||
rel[1] += W - 1
|
||
rel_index = rel[0] * (2 * W - 1) + rel[1] # [H*W, H*W]
|
||
return rel_index
|
||
```
|
||
|
||
**步骤二:在模型中注册偏置表**
|
||
|
||
需要自定义 `SelfAttentionWithRPB`,替换 `nn.TransformerEncoderLayer` 中的注意力:
|
||
|
||
```python
|
||
class SelfAttentionWithRPB(nn.Module):
|
||
def __init__(self, d_model: int, nhead: int, H: int, W: int):
|
||
super().__init__()
|
||
self.nhead = nhead
|
||
self.d_head = d_model // nhead
|
||
self.scale = self.d_head ** -0.5
|
||
self.qkv = nn.Linear(d_model, d_model * 3)
|
||
self.out_proj = nn.Linear(d_model, d_model)
|
||
# 偏置表:(2H-1) * (2W-1) 个可能的相对位置,每个头各一组
|
||
self.rel_bias_table = nn.Parameter(
|
||
torch.zeros(nhead, (2 * H - 1) * (2 * W - 1))
|
||
)
|
||
rel_index = build_relative_position_index(H, W) # [H*W, H*W]
|
||
self.register_buffer('rel_index', rel_index.flatten()) # [H*W * H*W]
|
||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
B, N, C = x.shape
|
||
qkv = self.qkv(x).reshape(B, N, 3, self.nhead, self.d_head)
|
||
q, k, v = qkv.unbind(2) # 各 [B, N, nhead, d_head]
|
||
q = q.permute(0, 2, 1, 3) # [B, nhead, N, d_head]
|
||
k = k.permute(0, 2, 1, 3)
|
||
v = v.permute(0, 2, 1, 3)
|
||
|
||
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, nhead, N, N]
|
||
|
||
# 从偏置表中取出对应偏置,reshape 成 [nhead, N, N]
|
||
bias = self.rel_bias_table[:, self.rel_index].reshape(self.nhead, N, N)
|
||
attn = attn + bias.unsqueeze(0)
|
||
|
||
attn = attn.softmax(dim=-1)
|
||
out = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, C)
|
||
return self.out_proj(out)
|
||
```
|
||
|
||
**步骤三:替换 TransformerEncoderLayer 中的注意力**
|
||
|
||
由于 PyTorch 标准 `TransformerEncoderLayer` 不支持直接替换注意力实现,需要手写包含 RPB 的编码器层:
|
||
|
||
```python
|
||
class RPBEncoderLayer(nn.Module):
|
||
def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int):
|
||
super().__init__()
|
||
self.norm1 = nn.LayerNorm(d_model)
|
||
self.attn = SelfAttentionWithRPB(d_model, nhead, H, W)
|
||
self.norm2 = nn.LayerNorm(d_model)
|
||
self.ffn = nn.Sequential(
|
||
nn.Linear(d_model, dim_ff),
|
||
nn.GELU(),
|
||
nn.Linear(dim_ff, d_model)
|
||
)
|
||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
x = x + self.attn(self.norm1(x))
|
||
x = x + self.ffn(self.norm2(x))
|
||
return x
|
||
```
|
||
|
||
#### 参数量评估
|
||
|
||
对于 Stage1 MaskGIT(d_model=192, nhead=8, num_layers=6,H=W=13):
|
||
|
||
- 偏置表:每层 `8 × 25 × 25 = 5000` 参数
|
||
- 共 6 层:30,000 参数
|
||
- 占总参数量的比例极低,但对 attention 的几何感知能力提升显著。
|
||
|
||
#### 特点
|
||
|
||
- 与方案 A 互补,可叠加使用;
|
||
- 理论上最接近"正确"的二维注意力感应偏置;
|
||
- 实现较复杂,需要手写注意力层;
|
||
- 参数量增加极少。
|
||
|
||
---
|
||
|
||
### 方案 C:轴向注意力(Axial Attention)
|
||
|
||
#### 思路
|
||
|
||
将每一个标准自注意力层替换为两个顺序执行的注意力:
|
||
|
||
1. **行轴注意力**:每一行内的格子互相注意,13 个 group,每组 13 个 token;
|
||
2. **列轴注意力**:每一列内的格子互相注意,13 个 group,每组 13 个 token。
|
||
|
||
两种注意力交替叠加:`Row → Col → Row → Col → ...`
|
||
|
||
```
|
||
标准自注意力:O((H*W)²) = O(169²) = 28,561
|
||
轴向注意力: O(H * W²) + O(W * H²) = O(H*W*(H+W)) = O(169*26) ≈ 4,394
|
||
```
|
||
|
||
复杂度大幅下降,且强制模型以行/列为单位建立空间关联。
|
||
|
||
#### 具体实现
|
||
|
||
```python
|
||
class AxialAttention(nn.Module):
|
||
def __init__(self, d_model: int, nhead: int, H: int, W: int):
|
||
super().__init__()
|
||
self.H = H
|
||
self.W = W
|
||
self.row_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||
self.col_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
|
||
self.norm_row = nn.LayerNorm(d_model)
|
||
self.norm_col = nn.LayerNorm(d_model)
|
||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
B, N, C = x.shape
|
||
H, W = self.H, self.W
|
||
x = x.reshape(B, H, W, C)
|
||
|
||
# 行注意力:每行的 W 个 token 互相注意
|
||
x_row = x.reshape(B * H, W, C)
|
||
x_row, _ = self.row_attn(x_row, x_row, x_row)
|
||
x = x + x_row.reshape(B, H, W, C)
|
||
x = self.norm_row(x)
|
||
|
||
# 列注意力:每列的 H 个 token 互相注意
|
||
x_col = x.permute(0, 2, 1, 3).reshape(B * W, H, C)
|
||
x_col, _ = self.col_attn(x_col, x_col, x_col)
|
||
x = x + x_col.reshape(B, W, H, C).permute(0, 2, 1, 3)
|
||
x = self.norm_col(x)
|
||
|
||
return x.reshape(B, N, C)
|
||
|
||
|
||
class AxialEncoderLayer(nn.Module):
|
||
def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int):
|
||
super().__init__()
|
||
self.axial_attn = AxialAttention(d_model, nhead, H, W)
|
||
self.norm_ff = nn.LayerNorm(d_model)
|
||
self.ffn = nn.Sequential(
|
||
nn.Linear(d_model, dim_ff),
|
||
nn.GELU(),
|
||
nn.Linear(dim_ff, d_model)
|
||
)
|
||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||
x = self.axial_attn(x)
|
||
x = x + self.ffn(self.norm_ff(x))
|
||
return x
|
||
```
|
||
|
||
#### 特点
|
||
|
||
- 最强的二维结构归纳偏置:明确区分行内关系和列内关系;
|
||
- 改造最彻底,需要替换整个 Transformer 编码器;
|
||
- 计算量低于标准全局注意力;
|
||
- 不适合 cross-attention 部分(decoder 中 map token cross-attend z 的部分保持不变)。
|
||
|
||
---
|
||
|
||
### 方案 D:双轴并行输入(Dual-Axis Input)
|
||
|
||
#### 思路
|
||
|
||
同一张地图展平两次,分别按行优先和列优先展平,两份序列并行送入各自的嵌入+编码器,最终在 token 维度相加合并:
|
||
|
||
```
|
||
地图 [H, W]
|
||
├─ 行优先展平 → [H*W] → Embedding + 行位置编码 → 编码器 A → [H*W, d_model]
|
||
└─ 列优先展平 → [W*H] → Embedding + 列位置编码 → 编码器 B → [W*H, d_model]
|
||
↓ 重排到行优先顺序
|
||
相加合并 → [H*W, d_model] → Decoder + z → logits
|
||
```
|
||
|
||
两个编码器可以共享权重(只有位置编码不同),以减少参数量。
|
||
|
||
#### 特点
|
||
|
||
- 无需修改注意力机制,复用现有 Transformer;
|
||
- 计算量加倍(两次编码),但可通过共享权重缓解;
|
||
- 同时为模型提供横向和纵向的序列上下文;
|
||
- 参数复用程度较高(共享编码器权重时)。
|
||
|
||
---
|
||
|
||
## 方案对比
|
||
|
||
| 方案 | 改动范围 | 额外参数量 | 二维感知能力 | 实现复杂度 |
|
||
| ------------------- | ------------ | ---------------------------------- | ------------ | ---------- |
|
||
| A:二维因式位置嵌入 | 位置嵌入层 | 减少(共享) | 中等 | 低 |
|
||
| B:二维相对位置偏置 | 注意力层 | 极少(~30k) | 强 | 中等 |
|
||
| C:轴向注意力 | 整个编码器 | 基本不变 | 最强 | 高 |
|
||
| D:双轴并行输入 | 输入与编码器 | 按编码器大小翻倍(共享权重则不变) | 中等 | 中等 |
|
||
|
||
---
|
||
|
||
## 推荐实施策略
|
||
|
||
### 第一步:替换位置嵌入(方案 A)
|
||
|
||
优先实施方案 A,这是改动最小、风险最低的基础改进。仅需修改 `GinkaMaskGIT` 和 `GinkaVQVAE` 中的 `pos_embedding` 初始化与使用方式。
|
||
|
||
修改范围:
|
||
|
||
- `ginka/maskGIT/model.py`:替换 `pos_embedding`,调整 `forward` 中的位置嵌入加法;
|
||
- `ginka/vqvae/model.py`:同上。
|
||
|
||
### 第二步:叠加二维相对位置偏置(方案 B)
|
||
|
||
在方案 A 的基础上,为 `GinkaMaskGIT` 的编码器部分叠加 RPB,只需新增一个自定义 `RPBEncoderLayer`,替换 `maskGIT.py` 中 `Transformer.encoder` 的层类型。
|
||
|
||
VQ-VAE 编码器的 RPB 改造优先级较低(VQ-VAE 负责全图压缩,对局部连通性感知需求低于 MaskGIT)。
|
||
|
||
### 可选第三步:轴向注意力替换(方案 C)
|
||
|
||
若前两步改进后 Stage1 质量仍不满足要求,可进一步将 `GinkaMaskGIT` 的编码器改为轴向注意力。由于 decoder(cross-attention 部分)不涉及空间 token,无需改动。
|
||
|
||
### 其他注意事项
|
||
|
||
- 上述改动仅针对 **Stage1 MaskGIT**。Stage2/3 可以同步修改,也可以保持原结构,视实际效果决定;
|
||
- 改动后需重置训练,位置嵌入的结构变化会使旧检查点不兼容(shape 不匹配);
|
||
- 验证时重点观察生成地图中墙壁的连通性、房间闭合度和整体拓扑是否改善。
|