mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 00:01:13 +08:00
feat: 训练时使用 target_map 做局部提取而不是自己输出的内容
This commit is contained in:
parent
d42c2eee43
commit
4d1179bcfb
@ -61,6 +61,10 @@ class GinkaMapPatch(nn.Module):
|
||||
res_bottom = 5
|
||||
|
||||
result[:, res_top:res_bottom, res_left:res_right] = map[:, top:bottom, left:right]
|
||||
# 没画到的地方要置为 0
|
||||
result[:, 4, 2] = 0
|
||||
result[:, 4, 3] = 0
|
||||
result[:, 4, 4] = 0
|
||||
result = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
||||
|
||||
return self.patch_cnn(result)
|
||||
@ -185,7 +189,7 @@ class GinkaRNNModel(nn.Module):
|
||||
# 位置编码、图块编码、地图局部编码
|
||||
tile_embed = self.tile_embedding(now_tile)
|
||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||
map_patch = self.map_patch(map, x, y)
|
||||
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
||||
# 编码特征融合
|
||||
feat = self.feat_fusion(tile_embed, cond, row_embed, col_embed, map_patch)
|
||||
# RNN 输出
|
||||
|
||||
Loading…
Reference in New Issue
Block a user