feat: 调整 decoder 的位置编码

This commit is contained in:
unanmed 2026-01-20 16:50:11 +08:00
parent 14f391f4f4
commit 4d244d021a

View File

@ -86,12 +86,15 @@ class GinkaPosEmbedding(nn.Module):
self.row_embedding = nn.Embedding(height, embed_dim)
self.col_embedding = nn.Embedding(width, embed_dim)
self.fusion = nn.Linear(embed_dim * 2, embed_dim)
def forward(self, x: torch.Tensor, y: torch.Tensor):
    """
    Return the fused positional embedding for the given grid coordinates.

    x: integer column indices, looked up in ``self.col_embedding``
    y: integer row indices, looked up in ``self.row_embedding``

    x and y must have the same shape. The decoder in this file passes
    flattened coordinate lists expanded over the batch (assumed [B, H*W]
    -- confirm against the caller).

    Returns a single tensor shaped like the inputs plus a trailing
    ``embed_dim`` axis (the output width of ``self.fusion``).
    """
    row = self.row_embedding(y)
    col = self.col_embedding(x)
    # Join the two lookups on the feature (last) axis. For the 2-D index
    # tensors the decoder passes this is the same axis as dim=2, but it
    # also keeps working for index tensors of any other rank.
    embed = torch.cat([row, col], dim=-1)
    # Project the concatenated [.., 2 * embed_dim] features back down to
    # embed_dim so callers receive one positional vector per cell.
    fused = self.fusion(embed)
    return fused
class GinkaInputFusion(nn.Module):
def __init__(self, d_model=256):
@ -109,16 +112,15 @@ class GinkaInputFusion(nn.Module):
def forward(
    self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
    pos_embed: torch.Tensor, patch_vec: torch.Tensor
):
    """
    Fuse the per-step decoder inputs into one feature vector.

    tile_embed: [B, 256]
    cond_vec: [B, 256]
    pos_embed: [B, 256]
    patch_vec: [B, 256]

    Returns the entry at index 0 along dim 1 of ``self.transformer``'s
    output, i.e. the slot that ``tile_embed`` occupied in the input
    (presumably [B, 256] if the transformer preserves shape -- confirm).
    """
    # Stack the four vectors on a new axis 1 so the transformer sees them
    # as a 4-token sequence: [B, 4, 256].
    vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1)
    feat = self.transformer(vec)
    # Only the first token's output is used as the fused feature.
    return feat[:, 0]
@ -168,6 +170,13 @@ class VAEDecoder(nn.Module):
self.feat_fusion = GinkaInputFusion()
self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
self.col_list = []
self.row_list = []
for y in range(0, height):
for x in range(0, width):
self.col_list.append(x)
self.row_list.append(y)
def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
"""
map_vec: [B, vec_dim]
@ -183,19 +192,21 @@ class VAEDecoder(nn.Module):
output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)
col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
pos_embed = self.pos_embedding(col_list, row_list)
map_vec = self.map_vec_fc(map_vec)
for y in range(0, self.height):
for x in range(0, self.width):
x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
# 位置编码、图块编码、地图局部编码
idx = y * self.width + x
# 图块编码、地图局部编码
tile_embed = self.tile_embedding(now_tile)
row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
use_self = random.random() < use_self_probility
map_patch = self.map_patch(map if use_self else target_map, x, y)
# 编码特征融合
feat = self.feat_fusion(tile_embed, map_vec, row_embed, col_embed, map_patch)
feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch)
# RNN 输出
logits, h = self.rnn(feat, hidden)
# 处理输出