feat: 训练时使用 target_map 做局部提取而不是自己输出的内容

This commit is contained in:
unanmed 2025-12-13 19:57:06 +08:00
parent d42c2eee43
commit 4d1179bcfb

View File

@ -61,6 +61,10 @@ class GinkaMapPatch(nn.Module):
res_bottom = 5
result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
# 没画到的地方要置为 0
result[:, 4, 2] = 0
result[:, 4, 3] = 0
result[:, 4, 4] = 0
result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
return self.patch_cnn(result)
@ -185,7 +189,7 @@ class GinkaRNNModel(nn.Module):
# 位置编码、图块编码、地图局部编码
tile_embed = self.tile_embedding(now_tile)
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
map_patch = self.map_patch(map, x, y)
map_patch = self.map_patch(map if use_self else target_map, x, y)
# 编码特征融合
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
# RNN 输出