mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
fix: inplace
This commit is contained in:
parent
771553b8b8
commit
27b8c56cd2
@ -183,7 +183,7 @@ class GinkaRNNModel(nn.Module):
|
|||||||
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
|
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
|
||||||
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
|
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
|
||||||
# 位置编码、图块编码、地图局部编码
|
# 位置编码、图块编码、地图局部编码
|
||||||
tile_embed = self.tile_embedding(target_map[:, y, x])
|
tile_embed = self.tile_embedding(now_tile)
|
||||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||||
map_patch = self.map_patch(map, x, y)
|
map_patch = self.map_patch(map, x, y)
|
||||||
# 编码特征融合
|
# 编码特征融合
|
||||||
@ -192,11 +192,11 @@ class GinkaRNNModel(nn.Module):
|
|||||||
logits, h = self.rnn(feat, hidden)
|
logits, h = self.rnn(feat, hidden)
|
||||||
# 处理输出
|
# 处理输出
|
||||||
output_logits[:, y, x] = logits[:]
|
output_logits[:, y, x] = logits[:]
|
||||||
hidden[:] = h[:]
|
hidden = h
|
||||||
probs = F.softmax(logits, dim=1)
|
probs = F.softmax(logits, dim=1)
|
||||||
tile_id = torch.argmax(probs, dim=1).detach()
|
tile_id = torch.argmax(probs, dim=1).detach()
|
||||||
map[:, y, x] = tile_id[:]
|
map[:, y, x] = tile_id[:]
|
||||||
# now_tile[:] = tile_id[:]
|
now_tile = tile_id if use_self else target_map[:, y, x].detach()
|
||||||
|
|
||||||
return output_logits, map
|
return output_logits, map
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user