From 3df7d595754977564e61a53d0c0ab0604b5adb21 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 14 May 2026 13:48:32 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BA=8C=E7=BB=B4=E7=BD=91=E6=A0=BC?= =?UTF-8?q?=E6=B3=A8=E6=84=8F=E5=8A=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/2d-spatial-awareness-design.md | 354 ++++++++++++++++++++++++++++ ginka/dataset.py | 8 +- ginka/maskGIT/model.py | 29 +-- ginka/train_seperated.py | 28 ++- ginka/vqvae/model.py | 23 +- 5 files changed, 404 insertions(+), 38 deletions(-) create mode 100644 docs/2d-spatial-awareness-design.md diff --git a/docs/2d-spatial-awareness-design.md b/docs/2d-spatial-awareness-design.md new file mode 100644 index 0000000..633ebab --- /dev/null +++ b/docs/2d-spatial-awareness-design.md @@ -0,0 +1,354 @@ +# Stage1 二维空间感知改进设计文档 + +## 问题诊断 + +### 核心现象 + +第一阶段(floor/wall 骨架生成)质量显著劣于第二、三阶段。生成结果常见表现为: + +- 墙壁分布破碎,缺乏连通性; +- 房间边界不完整,出现孤立墙块; +- 走廊未能形成闭合通道; +- 生成结果的空间拓扑结构与训练集分布偏差较大。 + +第二、三阶段(门/怪物/资源放置)问题相对较少,因为这些阶段的任务是"在已有结构上填充稀疏元素",对空间结构的整体一致性要求较低。 + +### 根本原因:一维位置编码与二维网格的结构失配 + +#### 现有架构 + +`GinkaMaskGIT` 当前使用以下位置编码: + +```python +self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02) +``` + +这是一个纯一维可学习位置嵌入,将 13×13=169 个格子视为一条线性序列。`GinkaVQVAE` 中同样如此。 + +`maskGIT.py` 中的 `Transformer` 使用标准 `nn.TransformerEncoder` + `nn.TransformerDecoder`,注意力机制中没有任何二维空间偏置。 + +#### 一维展平带来的结构失配 + +将 13×13 地图按行展平后,位置关系如下: + +``` +位置(0,12) → token 12 +位置(1, 0) → token 13 +``` + +这两个 token 在一维序列中相邻(距离 1),但在二维地图上相距 12 列(横跨整行)。而真正的二维邻居关系: + +``` +位置(0, 0) 和 位置(1, 0) 的二维距离 = 1 格 +对应 token 0 和 token 13,一维距离 = 13 +``` + +一维位置嵌入告诉模型"token 0 和 token 13 相距较远",但实际上它们是相邻的竖向邻格。注意力机制的相对偏置完全依赖位置嵌入的初始化,无法从一维嵌入中自动推断二维邻接关系。 + +#### 为何 Stage1 特别敏感 + +Stage1 负责生成 floor/wall 骨架,这是整个地图中**空间结构约束最强**的层次: + +- 墙壁需要形成封闭或半封闭的房间边界; +- 走廊需要是连通的、宽度一致的通道; +- 整体拓扑(房间数、对称性、外围走廊)需要全局一致。 + +上述约束全部是**二维局部连通性约束**:一个墙壁格子是否合理,取决于它的上下左右四个邻格,而非它前后若干 token。一维位置编码使模型必须从数据中隐式学习这种行列边界,代价高昂且泛化差。 + +相比之下,Stage2/3 的任务(在走廊或房间内散布门/怪物/资源)对全局结构的空间一致性要求较低,位置编码的精确性影响较小。 + +--- + +## 改进方案 + +### 方案 A:二维因式分解位置嵌入(推荐首选) + +#### 思路 + +将当前单一的一维位置嵌入替换为行嵌入与列嵌入的加和: + +``` +pos_embed[i, j] = row_embed[i] + col_embed[j] +``` + +这样,同一行的所有格子共享相同的行嵌入,同一列的所有格子共享相同的列嵌入。模型可以直接从嵌入中感知行列身份,而无需从一维序号中隐式推断。 + +#### 具体实现 + +```python +# 替换原有的 pos_embedding +self.row_embedding = nn.Parameter(torch.randn(1, MAP_H, d_model) * 0.02) +self.col_embedding = nn.Parameter(torch.randn(1, MAP_W, d_model) * 0.02) +``` + +前向传播中: + +```python +# map: [B, H*W] +row_idx = torch.arange(MAP_H, device=map.device).repeat_interleave(MAP_W) +col_idx = torch.arange(MAP_W, device=map.device).repeat(MAP_H) +pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx] +x = self.tile_embedding(map) + pos.unsqueeze(0) +``` + +也可以预计算展开后的索引并缓存。 + +#### 特点 + +- 参数量变化:从 `169 × d_model` 变为 `13 × d_model + 13 × d_model = 26 × d_model`,显著减少,且参数共享有助于泛化; +- 无需修改注意力机制,改动最小; +- 直接赋予模型行列语义,改进效果立竿见影。 + +--- + +### 方案 B:二维相对位置偏置(推荐次选,与 A 叠加) + +#### 思路 + +在注意力计算中,对每对 query-key 的打分加入一个可学习偏置,偏置由两个 token 的**相对行列偏移量**决定: + +``` +score(i, j) = (q_i · k_j) / sqrt(d) + B[Δrow, Δcol] +``` + +其中 `Δrow = row(i) - row(j)`,`Δcol = col(i) - col(j)`。偏置表 B 的形状为 `(2H-1, 2W-1)` = `(25, 25)`,每个注意力头各一张。 + +这种方式的核心优势:注意力打分天然理解"相邻格子应更强相关",模型无需从位置嵌入中隐式学习距离感。 + +#### 具体实现 + +**步骤一:预计算相对位置索引表** + +对于 13×13 的地图,预计算每对 token (i, j) 的相对位置,将二维偏移量映射为一维索引,供后续从偏置表中 gather: + +```python +def build_relative_position_index(H: int, W: int) -> torch.Tensor: + coords_h = torch.arange(H) + coords_w = torch.arange(W) + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing='ij')) # [2, H, W] + coords_flat = coords.flatten(1) # [2, H*W] + rel = coords_flat[:, :, None] - coords_flat[:, None, :] # [2, H*W, H*W] + rel[0] += H - 1 + rel[1] += W - 1 + rel_index = rel[0] * (2 * W - 1) + rel[1] # [H*W, H*W] + return rel_index +``` + +**步骤二:在模型中注册偏置表** + +需要自定义 `SelfAttentionWithRPB`,替换 `nn.TransformerEncoderLayer` 中的注意力: + +```python +class SelfAttentionWithRPB(nn.Module): + def __init__(self, d_model: int, nhead: int, H: int, W: int): + super().__init__() + self.nhead = nhead + self.d_head = d_model // nhead + self.scale = self.d_head ** -0.5 + self.qkv = nn.Linear(d_model, d_model * 3) + self.out_proj = nn.Linear(d_model, d_model) + # 偏置表:(2H-1) * (2W-1) 个可能的相对位置,每个头各一组 + self.rel_bias_table = nn.Parameter( + torch.zeros(nhead, (2 * H - 1) * (2 * W - 1)) + ) + rel_index = build_relative_position_index(H, W) # [H*W, H*W] + self.register_buffer('rel_index', rel_index.flatten()) # [H*W * H*W] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.nhead, self.d_head) + q, k, v = qkv.unbind(2) # 各 [B, N, nhead, d_head] + q = q.permute(0, 2, 1, 3) # [B, nhead, N, d_head] + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale # [B, nhead, N, N] + + # 从偏置表中取出对应偏置,reshape 成 [nhead, N, N] + bias = self.rel_bias_table[:, self.rel_index].reshape(self.nhead, N, N) + attn = attn + bias.unsqueeze(0) + + attn = attn.softmax(dim=-1) + out = (attn @ v).permute(0, 2, 1, 3).reshape(B, N, C) + return self.out_proj(out) +``` + +**步骤三:替换 TransformerEncoderLayer 中的注意力** + +由于 PyTorch 标准 `TransformerEncoderLayer` 不支持直接替换注意力实现,需要手写包含 RPB 的编码器层: + +```python +class RPBEncoderLayer(nn.Module): + def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int): + super().__init__() + self.norm1 = nn.LayerNorm(d_model) + self.attn = SelfAttentionWithRPB(d_model, nhead, H, W) + self.norm2 = nn.LayerNorm(d_model) + self.ffn = nn.Sequential( + nn.Linear(d_model, dim_ff), + nn.GELU(), + nn.Linear(dim_ff, d_model) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.norm1(x)) + x = x + self.ffn(self.norm2(x)) + return x +``` + +#### 参数量评估 + +对于 Stage1 MaskGIT(d_model=192, nhead=8, num_layers=6,H=W=13): + +- 偏置表:每层 `8 × 25 × 25 = 5000` 参数 +- 共 6 层:30,000 参数 +- 占总参数量的比例极低,但对 attention 的几何感知能力提升显著。 + +#### 特点 + +- 与方案 A 互补,可叠加使用; +- 理论上最接近"正确"的二维注意力感应偏置; +- 实现较复杂,需要手写注意力层; +- 参数量增加极少。 + +--- + +### 方案 C:轴向注意力(Axial Attention) + +#### 思路 + +将每一个标准自注意力层替换为两个顺序执行的注意力: + +1. **行轴注意力**:每一行内的格子互相注意,13 个 group,每组 13 个 token; +2. **列轴注意力**:每一列内的格子互相注意,13 个 group,每组 13 个 token。 + +两种注意力交替叠加:`Row → Col → Row → Col → ...` + +``` +标准自注意力:O((H*W)²) = O(169²) = 28,561 +轴向注意力: O(H * W²) + O(W * H²) = O(H*W*(H+W)) = O(169*26) ≈ 4,394 +``` + +复杂度大幅下降,且强制模型以行/列为单位建立空间关联。 + +#### 具体实现 + +```python +class AxialAttention(nn.Module): + def __init__(self, d_model: int, nhead: int, H: int, W: int): + super().__init__() + self.H = H + self.W = W + self.row_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.col_attn = nn.MultiheadAttention(d_model, nhead, batch_first=True) + self.norm_row = nn.LayerNorm(d_model) + self.norm_col = nn.LayerNorm(d_model) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, N, C = x.shape + H, W = self.H, self.W + x = x.reshape(B, H, W, C) + + # 行注意力:每行的 W 个 token 互相注意 + x_row = x.reshape(B * H, W, C) + x_row, _ = self.row_attn(x_row, x_row, x_row) + x = x + x_row.reshape(B, H, W, C) + x = self.norm_row(x) + + # 列注意力:每列的 H 个 token 互相注意 + x_col = x.permute(0, 2, 1, 3).reshape(B * W, H, C) + x_col, _ = self.col_attn(x_col, x_col, x_col) + x = x + x_col.reshape(B, W, H, C).permute(0, 2, 1, 3) + x = self.norm_col(x) + + return x.reshape(B, N, C) + + +class AxialEncoderLayer(nn.Module): + def __init__(self, d_model: int, nhead: int, dim_ff: int, H: int, W: int): + super().__init__() + self.axial_attn = AxialAttention(d_model, nhead, H, W) + self.norm_ff = nn.LayerNorm(d_model) + self.ffn = nn.Sequential( + nn.Linear(d_model, dim_ff), + nn.GELU(), + nn.Linear(dim_ff, d_model) + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.axial_attn(x) + x = x + self.ffn(self.norm_ff(x)) + return x +``` + +#### 特点 + +- 最强的二维结构归纳偏置:明确区分行内关系和列内关系; +- 改造最彻底,需要替换整个 Transformer 编码器; +- 计算量低于标准全局注意力; +- 不适合 cross-attention 部分(decoder 中 map token cross-attend z 的部分保持不变)。 + +--- + +### 方案 D:双轴并行输入(Dual-Axis Input) + +#### 思路 + +同一张地图展平两次,分别按行优先和列优先展平,两份序列并行送入各自的嵌入+编码器,最终在 token 维度相加合并: + +``` +地图 [H, W] + ├─ 行优先展平 → [H*W] → Embedding + 行位置编码 → 编码器 A → [H*W, d_model] + └─ 列优先展平 → [W*H] → Embedding + 列位置编码 → 编码器 B → [W*H, d_model] + ↓ 重排到行优先顺序 + 相加合并 → [H*W, d_model] → Decoder + z → logits +``` + +两个编码器可以共享权重(只有位置编码不同),以减少参数量。 + +#### 特点 + +- 无需修改注意力机制,复用现有 Transformer; +- 计算量加倍(两次编码),但可通过共享权重缓解; +- 同时为模型提供横向和纵向的序列上下文; +- 参数复用程度较高(共享编码器权重时)。 + +--- + +## 方案对比 + +| 方案 | 改动范围 | 额外参数量 | 二维感知能力 | 实现复杂度 | +| ------------------- | ------------ | ---------------------------------- | ------------ | ---------- | +| A:二维因式位置嵌入 | 位置嵌入层 | 减少(共享) | 中等 | 低 | +| B:二维相对位置偏置 | 注意力层 | 极少(~30k) | 强 | 中等 | +| C:轴向注意力 | 整个编码器 | 基本不变 | 最强 | 高 | +| D:双轴并行输入 | 输入与编码器 | 按编码器大小翻倍(共享权重则不变) | 中等 | 中等 | + +--- + +## 推荐实施策略 + +### 第一步:替换位置嵌入(方案 A) + +优先实施方案 A,这是改动最小、风险最低的基础改进。仅需修改 `GinkaMaskGIT` 和 `GinkaVQVAE` 中的 `pos_embedding` 初始化与使用方式。 + +修改范围: + +- `ginka/maskGIT/model.py`:替换 `pos_embedding`,调整 `forward` 中的位置嵌入加法; +- `ginka/vqvae/model.py`:同上。 + +### 第二步:叠加二维相对位置偏置(方案 B) + +在方案 A 的基础上,为 `GinkaMaskGIT` 的编码器部分叠加 RPB,只需新增一个自定义 `RPBEncoderLayer`,替换 `maskGIT.py` 中 `Transformer.encoder` 的层类型。 + +VQ-VAE 编码器的 RPB 改造优先级较低(VQ-VAE 负责全图压缩,对局部连通性感知需求低于 MaskGIT)。 + +### 可选第三步:轴向注意力替换(方案 C) + +若前两步改进后 Stage1 质量仍不满足要求,可进一步将 `GinkaMaskGIT` 的编码器改为轴向注意力。由于 decoder(cross-attention 部分)不涉及空间 token,无需改动。 + +### 其他注意事项 + +- 上述改动仅针对 **Stage1 MaskGIT**。Stage2/3 可以同步修改,也可以保持原结构,视实际效果决定; +- 改动后需重置训练,位置嵌入的结构变化会使旧检查点不兼容(shape 不匹配); +- 验证时重点观察生成地图中墙壁的连通性、房间闭合度和整体拓扑是否改善。 diff --git a/ginka/dataset.py b/ginka/dataset.py index 194a1af..3c9d9b9 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -33,11 +33,9 @@ class GinkaSeperatedDataset(Dataset): def __init__( self, data_path: str, - subset_weights: tuple = (0.5, 0.3, 0.2), - subset2_wall_prob: float = 0.7 + subset_weights: tuple = (0.5, 0.3, 0.2) ): self.data = load_data(data_path) - self.subset2_wall_prob = subset2_wall_prob total = sum(subset_weights) self.subset_cumw = [sum(subset_weights[:i+1]) / total for i in range(len(subset_weights))] @@ -134,8 +132,8 @@ class GinkaSeperatedDataset(Dataset): enc2 = inp2.copy() enc3 = raw.copy() - if np.random.random() < self.subset2_wall_prob: - inp1[self.std_mask()] = self.MASK_ID + need_mask = np.isin(inp2, [self.FLOOR, self.WALL]) + inp1[need_mask & self.std_mask()] = self.MASK_ID need_mask = np.isin(inp2, [self.FLOOR, self.DOOR, self.MONSTER, self.ENTRANCE]) inp2[need_mask] = self.MASK_ID need_mask = np.isin(inp3, [self.FLOOR, self.RESOURCE]) diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index d5244f3..b99ce12 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -13,22 +13,16 @@ OUTER_VOCAB = 2 # outerWall 0-1 class GinkaMaskGIT(nn.Module): def __init__( self, num_classes: int = 16, d_model: int = 192, dim_ff: int = 512, - nhead: int = 8, num_layers: int = 4, map_size: int = 13 * 13, d_z: int = 64 + nhead: int = 8, num_layers: int = 4, map_h: int = 13, map_w: int = 13, d_z: int = 64 ): - """ - Args: - num_classes: tile 类别数(含 MASK token=15) - d_model: Transformer 内部维度 - dim_ff: 前馈网络隐层维度 - nhead: 注意力头数 - num_layers: Transformer 层数 - map_size: 地图 token 总数(H * W) - """ super().__init__() - - # Tile 嵌入 + 位置编码 + self.map_h = map_h + self.map_w = map_w + + # Tile 嵌入 + 二维因式分解位置编码 self.tile_embedding = nn.Embedding(num_classes, d_model) - self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02) + self.row_embedding = nn.Parameter(torch.randn(1, map_h, d_model) * 0.02) + self.col_embedding = nn.Parameter(torch.randn(1, map_w, d_model) * 0.02) # z 投影:将 VQ 码字从 d_z 维映射到 d_model 维,供 cross-attention 使用 self.z_proj = nn.Sequential( @@ -92,8 +86,10 @@ class GinkaMaskGIT(nn.Module): z_mem = torch.cat([z_mem_vq, z_mem_struct], dim=1) # [B, L * 3 + 4, d_model] # tile embedding + 位置编码 - x = self.tile_embedding(map) # [B, H * W, d_model] - x = x + self.pos_embedding # [B, H * W, d_model] + row_idx = torch.arange(self.map_h, device=map.device).repeat_interleave(self.map_w) + col_idx = torch.arange(self.map_w, device=map.device).repeat(self.map_h) + pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx] # [H*W, d_model] + x = self.tile_embedding(map) + pos # [B, H * W, d_model] # Transformer:encoder 做 map 自注意力,decoder cross-attend z+struct x = self.transformer(x, memory=z_mem) # [B, H * W, d_model] @@ -120,7 +116,8 @@ if __name__ == "__main__": dim_ff=2048, nhead=8, num_layers=6, - map_size=13 * 13, + map_h=13, + map_w=13 ).to(device) print_memory(device, "初始化后") diff --git a/ginka/train_seperated.py b/ginka/train_seperated.py index 438d3a5..8e164ec 100644 --- a/ginka/train_seperated.py +++ b/ginka/train_seperated.py @@ -82,7 +82,6 @@ MAP_W = 13 # 地图宽度 MAP_H = 13 # 地图高度 MAP_SIZE = MAP_W * MAP_H # 地图大小 GENERATE_STEP = 18 # MaskGIT 采样步数 -SUBSET2_WALL_PROB = 0.7 # 子集2 进行墙壁掩码的概率 SUBSET_WEIGHTS = (0.5, 0.3, 0.2) # 每个子集的概率 MG_Z_DROPOUT = 0.1 # z 隐变量 Dropout 概率 @@ -127,8 +126,8 @@ def build_model(device: torch.device): # 三组 VQ-VAE 编码器:各自独立编码一个阶段的地图上下文(encoder_stage1/2/3) # 输出形状均为 [B, L, d_z],拼接后送入共用 quantizer vq_kwargs = dict( - num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL, - nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_size=MAP_SIZE + num_classes=NUM_CLASSES, L=VQ_L, K=VQ_K, d_model=VQ_D_MODEL, + nhead=VQ_NHEAD, num_layers=VQ_LAYERS, dim_ff=VQ_DIM_FF, map_h=MAP_H, map_w=MAP_W ) vq1 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage1 上下文(floor/wall) vq2 = GinkaVQVAE(**vq_kwargs).to(device) # 编码 stage2 上下文(door/monster/entrance) @@ -137,15 +136,15 @@ def build_model(device: torch.device): # 三个独立 MaskGIT 解码器,均接收完整的三阶段 z_q 作为条件 mg1 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE1_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE1_MG_DIM_FF, - nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_size=MAP_SIZE + nhead=STAGE1_MG_NHEAD, num_layers=STAGE1_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W ).to(device) mg2 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE2_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE2_MG_DIM_FF, - nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_size=MAP_SIZE + nhead=STAGE2_MG_NHEAD, num_layers=STAGE2_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W ).to(device) mg3 = GinkaMaskGIT( num_classes=NUM_CLASSES, d_model=STAGE3_MG_DMODEL, d_z=VQ_D_Z, dim_ff=STAGE3_MG_DIM_FF, - nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_size=MAP_SIZE + nhead=STAGE3_MG_NHEAD, num_layers=STAGE3_MG_NUM_LAYERS, map_h=MAP_H, map_w=MAP_W ).to(device) # 六个模型参数合并到同一优化器,端到端联合训练 @@ -527,6 +526,19 @@ def train(device: torch.device): models = build_model(device) vq1, vq2, vq3, mg1, mg2, mg3, quantizer, optimizer, scheduler = models + tqdm.write(f"Device: {device}") + model_list = [ + ("vq1", vq1), ("vq2", vq2), ("vq3", vq3), + ("mg1", mg1), ("mg2", mg2), ("mg3", mg3), + ("quantizer", quantizer) + ] + total_params = 0 + for name, m in model_list: + n = sum(p.numel() for p in m.parameters()) + total_params += n + tqdm.write(f"{name}: {n:,} params") + tqdm.write(f"Total: {total_params:,} params") + start_epoch = 0 if args.resume: @@ -550,14 +562,14 @@ def train(device: torch.device): os.makedirs("result/seperated", exist_ok=True) dataset = GinkaSeperatedDataset( - args.train, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB + args.train, subset_weights=SUBSET_WEIGHTS ) dataloader = DataLoader( dataset, batch_size=BATCH_SIZE, shuffle=True ) dataset_val = GinkaSeperatedDataset( - args.validate, subset_weights=SUBSET_WEIGHTS, subset2_wall_prob=SUBSET2_WALL_PROB + args.validate, subset_weights=SUBSET_WEIGHTS ) dataloader_val = DataLoader( dataset_val, batch_size=min(BATCH_SIZE, len(dataset_val) // 8), shuffle=True diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py index 8d4842e..0fd04d1 100644 --- a/ginka/vqvae/model.py +++ b/ginka/vqvae/model.py @@ -100,17 +100,20 @@ class VQDecodeHead(nn.Module): class GinkaVQVAE(nn.Module): def __init__( self, num_classes: int = 16, L: int = 2, K: int = 16, d_z: int = 64, d_model: int = 128, - nhead: int = 4, num_layers: int = 2, dim_ff: int = 256, map_size: int = 13 * 13 + nhead: int = 4, num_layers: int = 2, dim_ff: int = 256, map_h: int = 13, map_w: int = 13 ): super().__init__() self.L = L self.K = K + self.map_h = map_h + self.map_w = map_w # Tile 嵌入 self.tile_embedding = nn.Embedding(num_classes, d_model) - # 地图位置编码(仅覆盖 map_size 个位置,不含 summary token) - self.pos_embedding = nn.Parameter(torch.randn(1, map_size, d_model) * 0.02) + # 二维因式分解位置编码:行嵌入 + 列嵌入,共享行列语义 + self.row_embedding = nn.Parameter(torch.randn(1, map_h, d_model) * 0.02) + self.col_embedding = nn.Parameter(torch.randn(1, map_w, d_model) * 0.02) # L 个可学习 summary token,拼接到序列头部 self.summary_tokens = nn.Parameter(torch.randn(1, L, d_model) * 0.02) @@ -128,13 +131,15 @@ class GinkaVQVAE(nn.Module): nn.LayerNorm(d_z), ) - def forward(self, map: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def forward(self, map: torch.Tensor) -> torch.Tensor: # map: [B, H * W] - B, L = map.shape + B, _ = map.shape - x = self.tile_embedding(map) # [B, H * W, d_model] - x = x + self.pos_embedding # [B, H * W, d_model] - + row_idx = torch.arange(self.map_h, device=map.device).repeat_interleave(self.map_w) + col_idx = torch.arange(self.map_w, device=map.device).repeat(self.map_h) + pos = self.row_embedding[0, row_idx] + self.col_embedding[0, col_idx] # [H*W, d_model] + + x = self.tile_embedding(map) + pos # [B, H * W, d_model] summary = self.summary_tokens.expand(B, -1, -1) # [B, L, d_model] x = torch.cat([summary, x], dim=1) # [B, L+H*W, d_model] @@ -151,7 +156,7 @@ if __name__ == "__main__": model = GinkaVQVAE( num_classes=7, L=2, K=16, d_z=64, d_model=128, - nhead=4, num_layers=2, dim_ff=256, map_size=13 * 13, + nhead=4, num_layers=2, dim_ff=256, map_h=13, map_w=13 ).to(device) print_memory(device, "初始化后")