mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: inplace
This commit is contained in:
parent
771553b8b8
commit
27b8c56cd2
@ -183,7 +183,7 @@ class GinkaRNNModel(nn.Module):
|
||||
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
|
||||
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
|
||||
# 位置编码、图块编码、地图局部编码
|
||||
tile_embed = self.tile_embedding(target_map[:, y, x])
|
||||
tile_embed = self.tile_embedding(now_tile)
|
||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||
map_patch = self.map_patch(map, x, y)
|
||||
# 编码特征融合
|
||||
@ -192,11 +192,11 @@ class GinkaRNNModel(nn.Module):
|
||||
logits, h = self.rnn(feat, hidden)
|
||||
# 处理输出
|
||||
output_logits[:, y, x] = logits[:]
|
||||
hidden[:] = h[:]
|
||||
hidden = h
|
||||
probs = F.softmax(logits, dim=1)
|
||||
tile_id = torch.argmax(probs, dim=1).detach()
|
||||
map[:, y, x] = tile_id[:]
|
||||
# now_tile[:] = tile_id[:]
|
||||
now_tile = tile_id if use_self else target_map[:, y, x].detach()
|
||||
|
||||
return output_logits, map
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user