feat: incorporate previous-stage information into pre-training

Co-authored-by: Copilot <copilot@github.com>
unanmed 2026-05-06 22:51:57 +08:00
parent 7535ecc9fe
commit 3b86a78e0b
2 changed files with 15 additions and 8 deletions


@@ -145,7 +145,7 @@ def validate(
         # Channel 1
         z_q1, _, idx1, _, _, _ = enc1(s1)
-        logits1 = head1(z_q1)
+        logits1 = head1(z_q1, torch.zeros_like(raw_map))
         pred1 = logits1.argmax(dim=-1)  # [B, H*W]
         wall_m = (raw_map == 1)
         ch1_tp += (pred1[wall_m] == 1).sum().item()
@@ -155,7 +155,7 @@ def validate(
         # Channel 2
         z_q2, _, idx2, _, _, _ = enc2(s2)
-        logits2 = head2(z_q2)
+        logits2 = head2(z_q2, s1)
         pred2 = logits2.argmax(dim=-1)
         for t in CH2_LOSS:
             m = (raw_map == t)
@@ -166,7 +166,7 @@ def validate(
         # Channel 3
         z_q3, _, idx3, _, _, _ = enc3(s3)
-        logits3 = head3(z_q3)
+        logits3 = head3(z_q3, s2)
         pred3 = logits3.argmax(dim=-1)
         for t in CH3_LOSS:
             m = (raw_map == t)
@@ -294,19 +294,19 @@ def train():
         # ─── Channel 1 ───
         z_q1, _, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1)
-        logits1 = head1(z_q1)  # [B, H*W, C]
+        logits1 = head1(z_q1, torch.zeros_like(raw_map))  # [B, H*W, C]
         fl1 = masked_focal(logits1, raw_map, CH1_LOSS, gamma=FOCAL_GAMMA)
         loss1 = fl1 + VQ_BETA * commit_loss1 + VQ_GAMMA * entropy_loss1
         # ─── Channel 2 ───
         z_q2, _, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2)
-        logits2 = head2(z_q2)
+        logits2 = head2(z_q2, s1)
         fl2 = masked_focal(logits2, raw_map, CH2_LOSS, gamma=FOCAL_GAMMA)
         loss2 = fl2 + VQ_BETA * commit_loss2 + VQ_GAMMA * entropy_loss2
         # ─── Channel 3 ───
         z_q3, _, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3)
-        logits3 = head3(z_q3)
+        logits3 = head3(z_q3, s2)
         fl3 = masked_focal(logits3, raw_map, CH3_LOSS, gamma=FOCAL_GAMMA)
         loss3 = fl3 + VQ_BETA * commit_loss3 + VQ_GAMMA * entropy_loss3
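
In short, pre-training now chains the three channels: each decode head also receives the previous channel's ground-truth slice, and channel 1, which has no earlier stage, receives an all-zero map so the call signature stays uniform. Below is a minimal sketch of that chaining; DummyHead, the tensor shapes and the tile-ID range are illustrative stand-ins, not code from this repository.

import torch
import torch.nn as nn

class DummyHead(nn.Module):
    # Illustrative stand-in for VQDecodeHead's new (z_q, cond_map) interface.
    def __init__(self, d_z=32, map_size=169, num_classes=16):
        super().__init__()
        self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
        self.cond_embedding = nn.Embedding(num_classes, d_z)
        self.classifier = nn.Linear(d_z, num_classes)

    def forward(self, z_q, cond_map=None):
        x = self.pos_queries.expand(z_q.shape[0], -1, -1)
        if cond_map is not None:
            x = x + self.cond_embedding(cond_map)  # inject the previous stage's slice
        return self.classifier(x)                  # [B, map_size, num_classes]

B, map_size = 2, 169
z_q1 = z_q2 = z_q3 = torch.randn(B, 49, 32)    # quantized codes per channel
s1 = torch.randint(0, 16, (B, map_size))       # channel-1 slice
s2 = torch.randint(0, 16, (B, map_size))       # channel-2 slice
raw_map = torch.randint(0, 16, (B, map_size))  # full target map

head1, head2, head3 = DummyHead(), DummyHead(), DummyHead()
logits1 = head1(z_q1, torch.zeros_like(raw_map))  # channel 1: no previous stage, all-zero map
logits2 = head2(z_q2, s1)                         # channel 2 conditioned on channel 1's slice
logits3 = head3(z_q3, s2)                         # channel 3 conditioned on channel 2's slice

Note that the all-zero map for channel 1 is not a no-op: the head still adds the learned embedding of tile ID 0 to every positional query, presumably so that all three heads are trained with a conditioning input rather than falling back to the unconditioned path.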


@@ -65,6 +65,9 @@ class VQDecodeHead(nn.Module):
         # One learnable positional query per grid cell
         self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
+        # Conditioning-map embedding: maps slice-map tile IDs into d_z space, added onto the positional queries
+        self.cond_embedding = nn.Embedding(num_classes, d_z)
         # Stacked decoder layers
         self.layers = nn.ModuleList([
             _DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers)
@@ -73,16 +76,20 @@
         self.norm_out = nn.LayerNorm(d_z)
         self.classifier = nn.Linear(d_z, num_classes)
 
-    def forward(self, z_q: torch.Tensor) -> torch.Tensor:
+    def forward(self, z_q: torch.Tensor, cond_map: torch.Tensor | None = None) -> torch.Tensor:
         """
         Args:
-            z_q: [B, L, d_z]
+            z_q: [B, L, d_z] quantized z vectors
+            cond_map: [B, map_size] integer tile IDs of the conditioning slice map;
+                when None, falls back to pure positional queries, matching the old behaviour
         Returns:
             logits: [B, map_size, num_classes]
         """
         B = z_q.shape[0]
         x = self.pos_queries.expand(B, -1, -1)  # [B, map_size, d_z]
+        if cond_map is not None:
+            x = x + self.cond_embedding(cond_map)  # add slice-map context
         for layer in self.layers:
             x = layer(x, z_q)
         x = self.norm_out(x)
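
For reference, a self-contained sketch of the modified forward path is shown below: positional queries, plus an embedding of the conditioning slice map, cross-attending to the quantized codes. The decode layer is reduced to a single nn.MultiheadAttention block for brevity and all dimensions are made up; the repository's _DecodeLayer and hyper-parameters may differ. Passing cond_map=None reproduces the old behaviour exactly.

import torch
import torch.nn as nn

class CondDecodeHeadSketch(nn.Module):
    # Simplified stand-in for VQDecodeHead with the new conditioning path.
    def __init__(self, d_z=64, map_size=169, num_classes=32, nhead=4):
        super().__init__()
        self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
        self.cond_embedding = nn.Embedding(num_classes, d_z)
        self.cross_attn = nn.MultiheadAttention(d_z, nhead, batch_first=True)
        self.norm_out = nn.LayerNorm(d_z)
        self.classifier = nn.Linear(d_z, num_classes)

    def forward(self, z_q, cond_map=None):
        x = self.pos_queries.expand(z_q.shape[0], -1, -1)     # [B, map_size, d_z]
        if cond_map is not None:
            x = x + self.cond_embedding(cond_map)              # add previous-stage context
        x, _ = self.cross_attn(query=x, key=z_q, value=z_q)    # attend to quantized codes
        return self.classifier(self.norm_out(x))               # [B, map_size, num_classes]

head = CondDecodeHeadSketch()
z_q = torch.randn(2, 49, 64)                 # [B, L, d_z]
prev_slice = torch.randint(0, 32, (2, 169))  # [B, map_size] tile IDs of the previous stage
print(head(z_q).shape)                       # old behaviour, no conditioning
print(head(z_q, prev_slice).shape)           # both print torch.Size([2, 169, 32])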