diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 1c1ff92..ef83f52 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -61,6 +61,10 @@ class GinkaMapPatch(nn.Module): res_bottom = 5 result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right] + # 没画到的地方要置为 0 + result[:, 4, 2] = 0 + result[:, 4, 3] = 0 + result[:, 4, 4] = 0 result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() return self.patch_cnn(result) @@ -185,7 +189,7 @@ class GinkaRNNModel(nn.Module): # 位置编码、图块编码、地图局部编码 tile_embed = self.tile_embedding(now_tile) row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) - map_patch = self.map_patch(map, x, y) + map_patch = self.map_patch(map if use_self else target_map, x, y) # 编码特征融合 feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch) # RNN 输出