mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 20:41:12 +08:00
feat: 训练时使用 target_map 做局部提取而不是自己输出的内容
This commit is contained in:
parent
d42c2eee43
commit
4d1179bcfb
@ -61,6 +61,10 @@ class GinkaMapPatch(nn.Module):
|
|||||||
res_bottom = 5
|
res_bottom = 5
|
||||||
|
|
||||||
result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
|
result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
|
||||||
|
# 没画到的地方要置为 0
|
||||||
|
result[:, 4, 2] = 0
|
||||||
|
result[:, 4, 3] = 0
|
||||||
|
result[:, 4, 4] = 0
|
||||||
result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
||||||
|
|
||||||
return self.patch_cnn(result)
|
return self.patch_cnn(result)
|
||||||
@ -185,7 +189,7 @@ class GinkaRNNModel(nn.Module):
|
|||||||
# 位置编码、图块编码、地图局部编码
|
# 位置编码、图块编码、地图局部编码
|
||||||
tile_embed = self.tile_embedding(now_tile)
|
tile_embed = self.tile_embedding(now_tile)
|
||||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||||
map_patch = self.map_patch(map, x, y)
|
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
||||||
# 编码特征融合
|
# 编码特征融合
|
||||||
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
|
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
|
||||||
# RNN 输出
|
# RNN 输出
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user