From dd6a0434873b3845059f44a408e4443bbe4e7fac Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Tue, 20 Jan 2026 16:31:13 +0800
Subject: [PATCH] feat: switch VAE encoder to RNN
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ginka/vae_rnn/decoder.py |  14 +++---
 ginka/vae_rnn/encoder.py | 102 +++++++++++++++++++++++++++++++---------
 2 files changed, 87 insertions(+), 29 deletions(-)

diff --git a/ginka/vae_rnn/decoder.py b/ginka/vae_rnn/decoder.py
index a649829..d14f87e 100644
--- a/ginka/vae_rnn/decoder.py
+++ b/ginka/vae_rnn/decoder.py
@@ -84,12 +84,12 @@ class GinkaPosEmbedding(nn.Module):
         self.width = width
         self.height = height
 
-        self.row_embedding = nn.Embedding(width, embed_dim)
-        self.col_embedding = nn.Embedding(height, embed_dim)
+        self.row_embedding = nn.Embedding(height, embed_dim)
+        self.col_embedding = nn.Embedding(width, embed_dim)
 
     def forward(self, x: torch.Tensor, y: torch.Tensor):
-        row = self.row_embedding(x).squeeze(1)
-        col = self.col_embedding(y).squeeze(1)
+        row = self.row_embedding(y).squeeze(1)
+        col = self.col_embedding(x).squeeze(1)
         return row, col
 
 
@@ -104,7 +104,7 @@ class GinkaInputFusion(nn.Module):
                 d_model=d_model, nhead=2, dim_feedforward=d_model,
                 batch_first=True, dropout=0.2
             ),
-            num_layers=4
+            num_layers=3
         )
 
     def forward(
@@ -220,13 +220,13 @@ if __name__ == "__main__":
 
     # forward pass
     start = time.perf_counter()
-    fake_logits, fake_map = model(map_vec, input, 0)
+    fake_logits = model(map_vec, input, 0)
     end = time.perf_counter()
 
     print_memory("after forward pass")
 
     print(f"Inference time: {end - start}")
-    print(f"Output shapes: fake_logits={fake_logits.shape}, fake_map={fake_map.shape}")
+    print(f"Output shapes: fake_logits={fake_logits.shape}")
     print(f"Map Vector FC parameters: {sum(p.numel() for p in model.map_vec_fc.parameters())}")
     print(f"Tile Embedding parameters: {sum(p.numel() for p in model.tile_embedding.parameters())}")
     print(f"Position Embedding parameters: {sum(p.numel() for p in model.pos_embedding.parameters())}")
diff --git a/ginka/vae_rnn/encoder.py b/ginka/vae_rnn/encoder.py
index 419b1ba..57c907c 100644
--- a/ginka/vae_rnn/encoder.py
+++ b/ginka/vae_rnn/encoder.py
@@ -4,29 +4,87 @@ import torch.nn as nn
 import torch.nn.functional as F
 from ..utils import print_memory
 
-class VAEEncoder(nn.Module):
-    def __init__(self, tile_classes=32, latent_dim=32):
+class EncoderEmbedding(nn.Module):
+    def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256):
         super().__init__()
-        self.conv = nn.Sequential(
-            nn.Conv2d(tile_classes, 64, 3, padding=1),
-            nn.BatchNorm2d(64),
-            nn.ReLU(),
+        self.tile_embedding = nn.Embedding(tile_classes, hidden_dim)
+        self.col_embedding = nn.Embedding(width, hidden_dim)
+        self.row_embedding = nn.Embedding(height, hidden_dim)
+        self.fusion = nn.Linear(hidden_dim * 3, output_dim)
+
+    def forward(self, tile, x, y):
+        tile_embed = self.tile_embedding(tile)
+        col_embed = self.col_embedding(x)
+        row_embed = self.row_embedding(y)
+        embed = torch.cat([tile_embed, col_embed, row_embed], dim=2)
+        fused = self.fusion(embed)
+        return fused
+
+class EncoderGRU(nn.Module):
+    def __init__(self, input_dim=256, hidden_dim=512, output_dim=256):
+        super().__init__()
+
+        # GRU
+        self.gru = nn.GRUCell(input_dim, hidden_dim)
+        self.drop = nn.Dropout(0.1)
+        self.fc = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim),
+            nn.GELU(),
 
-            nn.Conv2d(64, 128, 3, stride=2, padding=1),
-            nn.BatchNorm2d(128),
-            nn.ReLU(),
-
-            nn.Conv2d(128, 256, 3, stride=2, padding=1),
-            nn.BatchNorm2d(256),
-            nn.ReLU(),
-
-            nn.Flatten()
+            nn.Linear(hidden_dim, output_dim)
         )
-        self.fc_mu = nn.Linear(256 * 4 * 4, latent_dim)
-        self.fc_logvar = nn.Linear(256 * 4 * 4, latent_dim)
+
+    def forward(self, feat: torch.Tensor, hidden: torch.Tensor):
+        """
+        feat: [B, input_dim]
+        hidden: [B, hidden_dim]
+        """
+        hidden = self.drop(self.gru(feat, hidden))
+        logits = self.fc(hidden)
+        return logits, hidden
+
+class VAEEncoder(nn.Module):
+    def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13):
+        super().__init__()
+        self.device = device
+
+        self.rnn_hidden = 512
+        self.logits_dim = 256
+
+        self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256)
+        self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim)
+        self.fc_mu = nn.Linear(512, latent_dim)
+        self.fc_logvar = nn.Linear(512, latent_dim)
+
+        self.col_list = []
+        self.row_list = []
+        for y in range(0, height):
+            for x in range(0, width):
+                self.col_list.append(x)
+                self.row_list.append(y)
 
-    def forward(self, x):
-        h = self.conv(x)
+    def forward(self, x: torch.Tensor):
+        B, H, W = x.shape
+
+        map = torch.flatten(x, start_dim=1)
+        hidden = torch.zeros(B, self.rnn_hidden).to(self.device)
+        output = torch.zeros(B, H * W, self.logits_dim).to(self.device)
+
+        col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1)
+        row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1)
+        embed = self.embedding(map, col_list, row_list)
+
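+        # Scan the flattened tile sequence one position at a time with
+        # the GRU cell, carrying the hidden state across steps and
+        # collecting each step's feature for the mean/max pooling below.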
+        for idx in range(0, len(self.col_list)):
+            logits, h = self.rnn(embed[:, idx], hidden)
+            hidden = h
+            output[:, idx] = logits
+        h_mean = torch.mean(output, dim=1)
+        h_max = torch.max(output, dim=1).values
+        h = torch.cat([h_mean, h_max], dim=1)
         mu = self.fc_mu(h)
         logvar = self.fc_logvar(h)
         return mu, logvar
@@ -35,10 +93,9 @@ if __name__ == "__main__":
     device = torch.device("cpu")
 
     input = torch.randint(0, 32, [1, 13, 13]).to(device)
-    input = F.one_hot(input, 32).permute(0, 3, 1, 2).float()
 
     # initialize the model
-    model = VAEEncoder().to(device)
+    model = VAEEncoder(device).to(device)
 
     print_memory("after initialization")
 
@@ -51,5 +108,6 @@
     print(f"Inference time: {end - start}")
     print(f"Output shapes: mu={mu.shape}, logvar={logvar.shape}")
 
-    print(f"CNN parameters: {sum(p.numel() for p in model.conv.parameters())}")
+    print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}")
+    print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")