fix: inplace

This commit is contained in:
unanmed 2025-12-13 17:48:51 +08:00
parent 771553b8b8
commit 27b8c56cd2

View File

@ -183,7 +183,7 @@ class GinkaRNNModel(nn.Module):
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
# 位置编码、图块编码、地图局部编码
tile_embed = self.tile_embedding(target_map[:, y, x])
tile_embed = self.tile_embedding(now_tile)
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
map_patch = self.map_patch(map, x, y)
# 编码特征融合
@ -192,11 +192,11 @@ class GinkaRNNModel(nn.Module):
logits, h = self.rnn(feat, hidden)
# 处理输出
output_logits[:, y, x] = logits[:]
hidden[:] = h[:]
hidden = h
probs = F.softmax(logits, dim=1)
tile_id = torch.argmax(probs, dim=1).detach()
map[:, y, x] = tile_id[:]
# now_tile[:] = tile_id[:]
now_tile = tile_id if use_self else target_map[:, y, x].detach()
return output_logits, map