From 4d1179bcfb35708620da66fd9d39349fa141f280 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 13 Dec 2025 19:57:06 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AE=AD=E7=BB=83=E6=97=B6=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=20target=5Fmap=20=E5=81=9A=E5=B1=80=E9=83=A8=E6=8F=90?= =?UTF-8?q?=E5=8F=96=E8=80=8C=E4=B8=8D=E6=98=AF=E8=87=AA=E5=B7=B1=E8=BE=93?= =?UTF-8?q?=E5=87=BA=E7=9A=84=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/rnn.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 1c1ff92..ef83f52 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -61,6 +61,10 @@ class GinkaMapPatch(nn.Module): res_bottom = 5 result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right] + # 没画到的地方要置为 0 + result[:, 4, 2] = 0 + result[:, 4, 3] = 0 + result[:, 4, 4] = 0 result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() return self.patch_cnn(result) @@ -185,7 +189,7 @@ class GinkaRNNModel(nn.Module): # 位置编码、图块编码、地图局部编码 tile_embed = self.tile_embedding(now_tile) row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) - map_patch = self.map_patch(map, x, y) + map_patch = self.map_patch(map if use_self else target_map, x, y) # 编码特征融合 feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch) # RNN 输出