mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 预训练加入前一阶段的信息
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
7535ecc9fe
commit
3b86a78e0b
@@ -145,7 +145,7 @@ def validate(
|
||||
|
||||
# 通道 1
|
||||
z_q1, _, idx1, _, _, _ = enc1(s1)
|
||||
logits1 = head1(z_q1)
|
||||
logits1 = head1(z_q1, torch.zeros_like(raw_map))
|
||||
pred1 = logits1.argmax(dim=-1) # [B, H*W]
|
||||
wall_m = (raw_map == 1)
|
||||
ch1_tp += (pred1[wall_m] == 1).sum().item()
|
||||
@@ -155,7 +155,7 @@ def validate(
|
||||
|
||||
# 通道 2
|
||||
z_q2, _, idx2, _, _, _ = enc2(s2)
|
||||
logits2 = head2(z_q2)
|
||||
logits2 = head2(z_q2, s1)
|
||||
pred2 = logits2.argmax(dim=-1)
|
||||
for t in CH2_LOSS:
|
||||
m = (raw_map == t)
|
||||
@@ -166,7 +166,7 @@ def validate(
|
||||
|
||||
# 通道 3
|
||||
z_q3, _, idx3, _, _, _ = enc3(s3)
|
||||
logits3 = head3(z_q3)
|
||||
logits3 = head3(z_q3, s2)
|
||||
pred3 = logits3.argmax(dim=-1)
|
||||
for t in CH3_LOSS:
|
||||
m = (raw_map == t)
|
||||
@@ -294,19 +294,19 @@ def train():
|
||||
|
||||
# ─── 通道 1 ───
|
||||
z_q1, _, _, vq_loss1, commit_loss1, entropy_loss1 = enc1(s1)
|
||||
logits1 = head1(z_q1) # [B, H*W, C]
|
||||
logits1 = head1(z_q1, torch.zeros_like(raw_map)) # [B, H*W, C]
|
||||
fl1 = masked_focal(logits1, raw_map, CH1_LOSS, gamma=FOCAL_GAMMA)
|
||||
loss1 = fl1 + VQ_BETA * commit_loss1 + VQ_GAMMA * entropy_loss1
|
||||
|
||||
# ─── 通道 2 ───
|
||||
z_q2, _, _, vq_loss2, commit_loss2, entropy_loss2 = enc2(s2)
|
||||
logits2 = head2(z_q2)
|
||||
logits2 = head2(z_q2, s1)
|
||||
fl2 = masked_focal(logits2, raw_map, CH2_LOSS, gamma=FOCAL_GAMMA)
|
||||
loss2 = fl2 + VQ_BETA * commit_loss2 + VQ_GAMMA * entropy_loss2
|
||||
|
||||
# ─── 通道 3 ───
|
||||
z_q3, _, _, vq_loss3, commit_loss3, entropy_loss3 = enc3(s3)
|
||||
logits3 = head3(z_q3)
|
||||
logits3 = head3(z_q3, s2)
|
||||
fl3 = masked_focal(logits3, raw_map, CH3_LOSS, gamma=FOCAL_GAMMA)
|
||||
loss3 = fl3 + VQ_BETA * commit_loss3 + VQ_GAMMA * entropy_loss3
|
||||
|
||||
|
||||
@@ -65,6 +65,9 @@ class VQDecodeHead(nn.Module):
|
||||
# 每个格子一个可学习位置查询
|
||||
self.pos_queries = nn.Parameter(torch.randn(1, map_size, d_z) * 0.02)
|
||||
|
||||
# 条件地图嵌入:将切片地图 tile ID 映射到 d_z 空间,叠加到位置查询
|
||||
self.cond_embedding = nn.Embedding(num_classes, d_z)
|
||||
|
||||
# 堆叠多层解码块
|
||||
self.layers = nn.ModuleList([
|
||||
_DecodeLayer(d_z, nhead, dim_ff) for _ in range(num_layers)
|
||||
@@ -73,16 +76,20 @@ class VQDecodeHead(nn.Module):
|
||||
self.norm_out = nn.LayerNorm(d_z)
|
||||
self.classifier = nn.Linear(d_z, num_classes)
|
||||
|
||||
def forward(self, z_q: torch.Tensor) -> torch.Tensor:
|
||||
def forward(self, z_q: torch.Tensor, cond_map: torch.Tensor | None = None) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_q: [B, L, d_z]
|
||||
z_q: [B, L, d_z] 量化后的 z 向量
|
||||
cond_map: [B, map_size] 条件切片地图(整数 tile ID);
|
||||
为 None 时退化为纯位置查询(与旧行为一致)
|
||||
|
||||
Returns:
|
||||
logits: [B, map_size, num_classes]
|
||||
"""
|
||||
B = z_q.shape[0]
|
||||
x = self.pos_queries.expand(B, -1, -1) # [B, map_size, d_z]
|
||||
if cond_map is not None:
|
||||
x = x + self.cond_embedding(cond_map) # 叠加切片上下文
|
||||
for layer in self.layers:
|
||||
x = layer(x, z_q)
|
||||
x = self.norm_out(x)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user