13 KiB
Stage1 二维空间感知改进设计文档
问题诊断
核心现象
第一阶段(floor/wall 骨架生成)质量显著劣于第二、三阶段。生成结果常见表现为:
- 墙壁分布破碎,缺乏连通性;
- 房间边界不完整,出现孤立墙块;
- 走廊未能形成闭合通道;
- 生成结果的空间拓扑结构与训练集分布偏差较大。
第二、三阶段(门/怪物/资源放置)问题相对较少,因为这些阶段的任务是"在已有结构上填充稀疏元素",对空间结构的整体一致性要求较低。
根本原因:一维位置编码与二维网格的结构失配
现有架构
GinkaMaskGIT 当前使用以下位置编码:
self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02)
这是一个纯一维可学习位置嵌入,将 13×13=169 个格子视为一条线性序列。GinkaVQVAE 中同样如此。
maskGIT.py 中的 Transformer 使用标准 nn.TransformerEncoder + nn.TransformerDecoder,注意力机制中没有任何二维空间偏置。
一维展平带来的结构失配
将 13×13 地图按行展平后,位置关系如下:
位置(0,12) → token 12
位置(1, 0) → token 13
这两个 token 在一维序列中相邻(距离 1),但在二维地图上相距 12 列(横跨整行)。而真正的二维邻居关系:
位置(0, 0) 和 位置(1, 0) 的二维距离 = 1 格
对应 token 0 和 token 13,一维距离 = 13
一维位置嵌入告诉模型"token 0 和 token 13 相距较远",但实际上它们是相邻的竖向邻格。注意力机制的相对偏置完全依赖位置嵌入的初始化,无法从一维嵌入中自动推断二维邻接关系。
为何 Stage1 特别敏感
Stage1 负责生成 floor/wall 骨架,这是整个地图中空间结构约束最强的层次:
- 墙壁需要形成封闭或半封闭的房间边界;
- 走廊需要是连通的、宽度一致的通道;
- 整体拓扑(房间数、对称性、外围走廊)需要全局一致。
上述约束全部是二维局部连通性约束:一个墙壁格子是否合理,取决于它的上下左右四个邻格,而非它前后若干 token。一维位置编码使模型必须从数据中隐式学习这种行列边界,代价高昂且泛化差。
相比之下,Stage2/3 的任务(在走廊或房间内散布门/怪物/资源)对全局结构的空间一致性要求较低,位置编码的精确性影响较小。
改进方案
方案 A:二维因式分解位置嵌入(推荐首选)
思路
将当前单一的一维位置嵌入替换为行嵌入与列嵌入的加和:
pos_embed[i, j] = row_embed[i] + col_embed[j]
这样,同一行的所有格子共享相同的行嵌入,同一列的所有格子共享相同的列嵌入。模型可以直接从嵌入中感知行列身份,而无需从一维序号中隐式推断。
具体实现
# 替换原有的 pos_embedding
self.row_embedding = nn.Parameter(torch.randn(1, MAP_H, d_model) * 0.02)
self.col_embedding = nn.Parameter(torch.randn(1, MAP_W, d_model) * 0.02)
前向传播中:
# map: [B, H*W]
row_idx = torch.arange(MAP_H, device=map.device).repeat_interleave(MAP_W)
col_idx = torch.arange(MAP_W, device=map.device).repeat(MAP_H)
pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx]
x = self.tile_embedding(map) + pos.unsqueeze(0)
也可以预计算展开后的索引并缓存。
特点
- 参数量变化:从
169 × d_model变为13 × d_model + 13 × d_model = 26 × d_model,显著减少,且参数共享有助于泛化; - 无需修改注意力机制,改动最小;
- 直接赋予模型行列语义,改进效果立竿见影。
方案 B:二维相对位置偏置(推荐次选,与 A 叠加)
思路
在注意力计算中,对每对 query-key 的打分加入一个可学习偏置,偏置由两个 token 的相对行列偏移量决定:
score(i, j) = (q_i · k_j) / sqrt(d) + B[Δrow, Δcol]
其中 Δrow = row(i) - row(j),Δcol = col(i) - col(j)。偏置表 B 的形状为 (2H-1, 2W-1) = (25, 25),每个注意力头各一张。
这种方式的核心优势:注意力打分天然理解"相邻格子应更强相关",模型无需从位置嵌入中隐式学习距离感。
具体实现
步骤一:预计算相对位置索引表
对于 13×13 的地图,预计算每对 token (i, j) 的相对位置,将二维偏移量映射为一维索引,供后续从偏置表中 gather:
def build_relative_position_index(H: int, W: int) -> torch.Tensor:
coords_h = torch.arange(H)
coords_w = torch.arange(W)
coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij')) # [2, H, W]
coords_flat = coords.flatten(1) # [2, H*W]
rel = coords_flat[:, :, None] - coords_flat[:, None, :] # [2, H*W, H*W]
rel[0] += H - 1
rel[1] += W - 1
rel_index = rel[0] * (2 * W - 1) + rel[1] # [H*W, H*W]
return rel_index
步骤二:在模型中注册偏置表
需要自定义 SelfAttentionWithRPB,替换 nn.TransformerEncoderLayer 中的注意力:
class SelfAttentionWithRPB(nn.Module):
def __init__(self, d_model: int, nhead: int, H: int, W: int):
super().__init__()
self.nhead = nhead
self.d_head = d_model // nhead
self.scale = self.d_head ** -0.5
self.qkv = nn.Linear(d_model, d_model * 3)
self.out_proj = nn.Linear(d_model, d_model)
# 偏置表:(2H-1) * (2W-1) 个可能的相对位置,每个头各一组
self.rel_bias_table = nn.Parameter(
torch.zeros(nhead, (2 * H - 1) * (2 * W - 1))
)
rel_index = build_relative_position_index(H, W) # [H*W, H*W]
self.register_buffer('rel_index', rel_index.flatten()) # [H*W * H*W]
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.nhead, self.d_head)
q, k, v = qkv.unbind(2) # 各 [B, N, nhead, d_head]
q = q.permute(0, 2, 1, 3) # [B, nhead, N, d_head]
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale # [B, nhead, N, N]
# 从偏置表中取出对应偏置,reshape 成 [nhead, N, N]
bias = self.rel_bias_table[:, self.rel_index].reshape(self.nhead, N, N)
attn = attn + bias.unsqueeze(0)
attn = attn.softmax(dim=-1)
out = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, C)
return self.out_proj(out)
步骤三:替换 TransformerEncoderLayer 中的注意力
由于 PyTorch 标准 TransformerEncoderLayer 不支持直接替换注意力实现,需要手写包含 RPB 的编码器层:
class RPBEncoderLayer(nn.Module):
def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.attn = SelfAttentionWithRPB(d_model, nhead, H, W)
self.norm2 = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_ff),
nn.GELU(),
nn.Linear(dim_ff, d_model)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.attn(self.norm1(x))
x = x + self.ffn(self.norm2(x))
return x
参数量评估
对于 Stage1 MaskGIT(d_model=192, nhead=8, num_layers=6,H=W=13):
- 偏置表:每层
8 × 25 × 25 = 5000参数 - 共 6 层:30,000 参数
- 占总参数量的比例极低,但对 attention 的几何感知能力提升显著。
特点
- 与方案 A 互补,可叠加使用;
- 理论上最接近"正确"的二维注意力感应偏置;
- 实现较复杂,需要手写注意力层;
- 参数量增加极少。
方案 C:轴向注意力(Axial Attention)
思路
将每一个标准自注意力层替换为两个顺序执行的注意力:
- 行轴注意力:每一行内的格子互相注意,13 个 group,每组 13 个 token;
- 列轴注意力:每一列内的格子互相注意,13 个 group,每组 13 个 token。
两种注意力交替叠加:Row → Col → Row → Col → ...
标准自注意力:O((H*W)²) = O(169²) = 28,561
轴向注意力: O(H * W²) + O(W * H²) = O(H*W*(H+W)) = O(169*26) ≈ 4,394
复杂度大幅下降,且强制模型以行/列为单位建立空间关联。
具体实现
class AxialAttention(nn.Module):
def __init__(self, d_model: int, nhead: int, H: int, W: int):
super().__init__()
self.H = H
self.W = W
self.row_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.col_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True)
self.norm_row = nn.LayerNorm(d_model)
self.norm_col = nn.LayerNorm(d_model)
def forward(self, x: torch.Tensor) -> torch.Tensor:
B, N, C = x.shape
H, W = self.H, self.W
x = x.reshape(B, H, W, C)
# 行注意力:每行的 W 个 token 互相注意
x_row = x.reshape(B * H, W, C)
x_row, _ = self.row_attn(x_row, x_row, x_row)
x = x + x_row.reshape(B, H, W, C)
x = self.norm_row(x)
# 列注意力:每列的 H 个 token 互相注意
x_col = x.permute(0, 2, 1, 3).reshape(B * W, H, C)
x_col, _ = self.col_attn(x_col, x_col, x_col)
x = x + x_col.reshape(B, W, H, C).permute(0, 2, 1, 3)
x = self.norm_col(x)
return x.reshape(B, N, C)
class AxialEncoderLayer(nn.Module):
def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int):
super().__init__()
self.axial_attn = AxialAttention(d_model, nhead, H, W)
self.norm_ff = nn.LayerNorm(d_model)
self.ffn = nn.Sequential(
nn.Linear(d_model, dim_ff),
nn.GELU(),
nn.Linear(dim_ff, d_model)
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.axial_attn(x)
x = x + self.ffn(self.norm_ff(x))
return x
特点
- 最强的二维结构归纳偏置:明确区分行内关系和列内关系;
- 改造最彻底,需要替换整个 Transformer 编码器;
- 计算量低于标准全局注意力;
- 不适合 cross-attention 部分(decoder 中 map token cross-attend z 的部分保持不变)。
方案 D:双轴并行输入(Dual-Axis Input)
思路
同一张地图展平两次,分别按行优先和列优先展平,两份序列并行送入各自的嵌入+编码器,最终在 token 维度相加合并:
地图 [H, W]
├─ 行优先展平 → [H*W] → Embedding + 行位置编码 → 编码器 A → [H*W, d_model]
└─ 列优先展平 → [W*H] → Embedding + 列位置编码 → 编码器 B → [W*H, d_model]
↓ 重排到行优先顺序
相加合并 → [H*W, d_model] → Decoder + z → logits
两个编码器可以共享权重(只有位置编码不同),以减少参数量。
特点
- 无需修改注意力机制,复用现有 Transformer;
- 计算量加倍(两次编码),但可通过共享权重缓解;
- 同时为模型提供横向和纵向的序列上下文;
- 参数复用程度较高(共享编码器权重时)。
方案对比
| 方案 | 改动范围 | 额外参数量 | 二维感知能力 | 实现复杂度 |
|---|---|---|---|---|
| A:二维因式位置嵌入 | 位置嵌入层 | 减少(共享) | 中等 | 低 |
| B:二维相对位置偏置 | 注意力层 | 极少(~30k) | 强 | 中等 |
| C:轴向注意力 | 整个编码器 | 基本不变 | 最强 | 高 |
| D:双轴并行输入 | 输入与编码器 | 按编码器大小翻倍(共享权重则不变) | 中等 | 中等 |
推荐实施策略
第一步:替换位置嵌入(方案 A)
优先实施方案 A,这是改动最小、风险最低的基础改进。仅需修改 GinkaMaskGIT 和 GinkaVQVAE 中的 pos_embedding 初始化与使用方式。
修改范围:
ginka/maskGIT/model.py:替换pos_embedding,调整forward中的位置嵌入加法;ginka/vqvae/model.py:同上。
第二步:叠加二维相对位置偏置(方案 B)
在方案 A 的基础上,为 GinkaMaskGIT 的编码器部分叠加 RPB,只需新增一个自定义 RPBEncoderLayer,替换 maskGIT.py 中 Transformer.encoder 的层类型。
VQ-VAE 编码器的 RPB 改造优先级较低(VQ-VAE 负责全图压缩,对局部连通性感知需求低于 MaskGIT)。
可选第三步:轴向注意力替换(方案 C)
若前两步改进后 Stage1 质量仍不满足要求,可进一步将 GinkaMaskGIT 的编码器改为轴向注意力。由于 decoder(cross-attention 部分)不涉及空间 token,无需改动。
其他注意事项
- 上述改动仅针对 Stage1 MaskGIT。Stage2/3 可以同步修改,也可以保持原结构,视实际效果决定;
- 改动后需重置训练,位置嵌入的结构变化会使旧检查点不兼容(shape 不匹配);
- 验证时重点观察生成地图中墙壁的连通性、房间闭合度和整体拓扑是否改善。