diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 8691daa..69cf75c 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -183,7 +183,7 @@ class GinkaRNNModel(nn.Module): x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1) y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1) # 位置编码、图块编码、地图局部编码 - tile_embed = self.tile_embedding(target_map[:, y, x]) + tile_embed = self.tile_embedding(now_tile) row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor) map_patch = self.map_patch(map, x, y) # 编码特征融合 @@ -192,11 +192,11 @@ class GinkaRNNModel(nn.Module): logits, h = self.rnn(feat, hidden) # 处理输出 output_logits[:, y, x] = logits[:] - hidden[:] = h[:] + hidden = h probs = F.softmax(logits, dim=1) tile_id = torch.argmax(probs, dim=1).detach() map[:, y, x] = tile_id[:] - # now_tile[:] = tile_id[:] + now_tile = tile_id if use_self else target_map[:, y, x].detach() return output_logits, map