mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 调整 decoder 的位置编码
This commit is contained in:
parent
14f391f4f4
commit
4d244d021a
@ -86,12 +86,15 @@ class GinkaPosEmbedding(nn.Module):
|
||||
|
||||
self.row_embedding = nn.Embedding(height, embed_dim)
|
||||
self.col_embedding = nn.Embedding(width, embed_dim)
|
||||
self.fusion = nn.Linear(embed_dim * 2, embed_dim)
|
||||
|
||||
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
||||
row = self.row_embedding(y).squeeze(1)
|
||||
col = self.col_embedding(x).squeeze(1)
|
||||
row = self.row_embedding(y)
|
||||
col = self.col_embedding(x)
|
||||
embed = torch.cat([row, col], dim=2)
|
||||
fused = self.fusion(embed)
|
||||
|
||||
return row, col
|
||||
return fused
|
||||
|
||||
class GinkaInputFusion(nn.Module):
|
||||
def __init__(self, d_model=256):
|
||||
@ -109,16 +112,15 @@ class GinkaInputFusion(nn.Module):
|
||||
|
||||
def forward(
|
||||
self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
|
||||
row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
|
||||
pos_embed: torch.Tensor, patch_vec: torch.Tensor
|
||||
):
|
||||
"""
|
||||
tile_embed: [B, 256]
|
||||
cond_vec: [B, 256]
|
||||
row_embed: [B, 256]
|
||||
col_embed: [B, 256]
|
||||
pos_embed: [B, 256]
|
||||
patch_vec: [B, 256]
|
||||
"""
|
||||
vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
|
||||
vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1)
|
||||
feat = self.transformer(vec)
|
||||
return feat[:, 0]
|
||||
|
||||
@ -168,6 +170,13 @@ class VAEDecoder(nn.Module):
|
||||
self.feat_fusion = GinkaInputFusion()
|
||||
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
|
||||
|
||||
self.col_list = []
|
||||
self.row_list = []
|
||||
for y in range(0, height):
|
||||
for x in range(0, width):
|
||||
self.col_list.append(x)
|
||||
self.row_list.append(y)
|
||||
|
||||
def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
|
||||
"""
|
||||
map_vec: [B, vec_dim]
|
||||
@ -183,19 +192,21 @@ class VAEDecoder(nn.Module):
|
||||
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
|
||||
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)
|
||||
|
||||
col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
|
||||
row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
|
||||
pos_embed = self.pos_embedding(col_list, row_list)
|
||||
|
||||
map_vec = self.map_vec_fc(map_vec)
|
||||
|
||||
for y in range(0, self.height):
|
||||
for x in range(0, self.width):
|
||||
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
|
||||
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
|
||||
# 位置编码、图块编码、地图局部编码
|
||||
idx = y * self.width + x
|
||||
# 图块编码、地图局部编码
|
||||
tile_embed = self.tile_embedding(now_tile)
|
||||
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
|
||||
use_self = random.random() < use_self_probility
|
||||
map_patch = self.map_patch(map if use_self else target_map, x, y)
|
||||
# 编码特征融合
|
||||
feat = self.feat_fusion(tile_embed, map_vec, row_embed, col_embed, map_patch)
|
||||
feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch)
|
||||
# RNN 输出
|
||||
logits, h = self.rnn(feat, hidden)
|
||||
# 处理输出
|
||||
|
||||
Loading…
Reference in New Issue
Block a user