diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py index 9841306..25b3871 100644 --- a/ginka/train_pretrain_split.py +++ b/ginka/train_pretrain_split.py @@ -145,7 +145,7 @@ def validate( # 通道 1 z_q1, _, idx1, _, _, _ = enc1(s1) - logits1 = head1(z_q1) + logits1 = head1(z_q1, torch.zeros_like(raw_map)) pred1 = logits1.argmax(dim=-1) # [B, H*W] wall_m = (raw_map == 1) ch1_tp += (pred1[wall_m] == 1).sum().item() @@ -155,7 +155,7 @@ def validate( # 通道 2 z_q2, _, idx2, _, _, _ = enc2(s2) - logits2 = head2(z_q2) + logits2 = head2(z_q2, s1) pred2 = logits2.argmax(dim=-1) for t in CH2_LOSS: m = (raw_map == t) @@ -166,7 +166,7 @@ def validate( # 通道 3 z_q3, _, idx3, _, _, _ = enc3(s3) - logits3 = head3(z_q3) + logits3 = head3(z_q3, s2) pred3 = logits3.argmax(dim=-1) for t in CH3_LOSS: m = (raw_map == t) @@ -294,19 +294,19 @@ def train(): # ─── 通道 1 ─── z_q1, _, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1) - logits1 = head1(z_q1) # [B, H*W, C] + logits1 = head1(z_q1, torch.zeros_like(raw_map)) # [B, H*W, C] fl1 = masked_focal(logits1, raw_map, CH1_LOSS, gamma=FOCAL_GAMMA) loss1 = fl1 + VQ_BETA * commit_loss1 + VQ_GAMMA * entropy_loss1 # ─── 通道 2 ─── z_q2, _, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2) - logits2 = head2(z_q2) + logits2 = head2(z_q2, s1) fl2 = masked_focal(logits2, raw_map, CH2_LOSS, gamma=FOCAL_GAMMA) loss2 = fl2 + VQ_BETA * commit_loss2 + VQ_GAMMA * entropy_loss2 # ─── 通道 3 ─── z_q3, _, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3) - logits3 = head3(z_q3) + logits3 = head3(z_q3, s2) fl3 = masked_focal(logits3, raw_map, CH3_LOSS, gamma=FOCAL_GAMMA) loss3 = fl3 + VQ_BETA * commit_loss3 + VQ_GAMMA * entropy_loss3 diff --git a/ginka/vqvae/model.py b/ginka/vqvae/model.py index a1ab105..ad2d56d 100644 --- a/ginka/vqvae/model.py +++ b/ginka/vqvae/model.py @@ -65,6 +65,9 @@ class VQDecodeHead(nn.Module): # 每个格子一个可学习位置查询 self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02) + # 条件地图嵌入:将切片地图 tile ID 映射到 d_z 空间,叠加到位置查询 + self.cond_embedding = nn.Embedding(num_classes, d_z) + # 堆叠多层解码块 self.layers = nn.ModuleList([ _DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers) @@ -73,16 +76,20 @@ class VQDecodeHead(nn.Module): self.norm_out = nn.LayerNorm(d_z) self.classifier = nn.Linear(d_z, num_classes) - def forward(self, z_q: torch.Tensor) -> torch.Tensor: + def forward(self, z_q: torch.Tensor, cond_map: torch.Tensor | None = None) -> torch.Tensor: """ Args: - z_q: [B, L, d_z] + z_q: [B, L, d_z] 量化后的 z 向量 + cond_map: [B, map_size] 条件切片地图(整数 tile ID); + 为 None 时退化为纯位置查询(与旧行为一致) Returns: logits: [B, map_size, num_classes] """ B = z_q.shape[0] x = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z] + if cond_map is not None: + x = x + self.cond_embedding(cond_map) # 叠加切片上下文 for layer in self.layers: x = layer(x, z_q) x = self.norm_out(x)