Fuse the separate row/column positional embeddings into a single vector
(concatenate, then project with a Linear layer), and compute the positional
embedding of every grid cell once per VAEDecoder.forward call instead of
once per cell inside the y/x loops. GinkaInputFusion.forward therefore takes
one pos_embed argument in place of row_embed and col_embed.

NOTE(review): this patch was recovered from a copy whose newlines and
indentation had been collapsed. Blank context lines were placed from the
hunk line counts and indentation was re-derived from the Python structure;
confirm both against the target file before applying.

diff --git a/ginka/vae_rnn/decoder.py b/ginka/vae_rnn/decoder.py
index d14f87e..05452dc 100644
--- a/ginka/vae_rnn/decoder.py
+++ b/ginka/vae_rnn/decoder.py
@@ -86,12 +86,15 @@ class GinkaPosEmbedding(nn.Module):
 
         self.row_embedding = nn.Embedding(height, embed_dim)
         self.col_embedding = nn.Embedding(width, embed_dim)
+        self.fusion = nn.Linear(embed_dim * 2, embed_dim)
 
     def forward(self, x: torch.Tensor, y: torch.Tensor):
-        row = self.row_embedding(y).squeeze(1)
-        col = self.col_embedding(x).squeeze(1)
+        row = self.row_embedding(y)
+        col = self.col_embedding(x)
+        embed = torch.cat([row, col], dim=2)
+        fused = self.fusion(embed)
 
-        return row, col
+        return fused
 
 class GinkaInputFusion(nn.Module):
     def __init__(self, d_model=256):
@@ -109,16 +112,15 @@ class GinkaInputFusion(nn.Module):
 
     def forward(
         self, tile_embed: torch.Tensor, cond_vec: torch.Tensor,
-        row_embed: torch.Tensor, col_embed: torch.Tensor, patch_vec: torch.Tensor
+        pos_embed: torch.Tensor, patch_vec: torch.Tensor
     ):
         """
         tile_embed: [B, 256]
         cond_vec: [B, 256]
-        row_embed: [B, 256]
-        col_embed: [B, 256]
+        pos_embed: [B, 256]
         patch_vec: [B, 256]
         """
-        vec = torch.stack([tile_embed, cond_vec, row_embed, col_embed, patch_vec], dim=1)
+        vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1)
         feat = self.transformer(vec)
         return feat[:, 0]
 
@@ -168,6 +170,13 @@ class VAEDecoder(nn.Module):
         self.feat_fusion = GinkaInputFusion()
         self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden)
 
+        self.col_list = []
+        self.row_list = []
+        for y in range(0, height):
+            for x in range(0, width):
+                self.col_list.append(x)
+                self.row_list.append(y)
+
     def forward(self, map_vec: torch.Tensor, target_map: torch.Tensor, use_self_probility=0):
         """
         map_vec: [B, vec_dim]
@@ -183,19 +192,21 @@ class VAEDecoder(nn.Module):
         output_logits = torch.zeros([B, self.height, self.width, self.tile_classes]).to(self.device)
         hidden: torch.Tensor = torch.zeros(B, self.rnn_hidden).to(self.device)
 
+        col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
+        row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
+        pos_embed = self.pos_embedding(col_list, row_list)
+
         map_vec = self.map_vec_fc(map_vec)
 
         for y in range(0, self.height):
             for x in range(0, self.width):
-                x_tensor = torch.LongTensor([x]).to(self.device).expand(B, -1)
-                y_tensor = torch.LongTensor([y]).to(self.device).expand(B, -1)
-                # 位置编码、图块编码、地图局部编码
+                idx = y * self.width + x
+                # 图块编码、地图局部编码
                 tile_embed = self.tile_embedding(now_tile)
-                row_embed, col_embed = self.pos_embedding(x_tensor, y_tensor)
                 use_self = random.random() < use_self_probility
                 map_patch = self.map_patch(map if use_self else target_map, x, y)
                 # 编码特征融合
-                feat = self.feat_fusion(tile_embed, map_vec, row_embed, col_embed, map_patch)
+                feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch)
                 # RNN 输出
                 logits, h = self.rnn(feat, hidden)
                 # 处理输出