From b7fe24ee4c8a29b89328b408a3f52f78a1cf57a3 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 20 Jan 2026 23:48:25 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E7=BC=96=E7=A0=81=E5=99=A8=E5=8A=A0?= =?UTF-8?q?=E5=85=A5=20transformer=20=E6=95=B4=E5=90=88=EF=BC=8C=E6=94=B9?= =?UTF-8?q?=E8=BF=9B=E8=A7=A3=E7=A0=81=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 1 + ginka/vae_rnn/decoder.py | 58 +++++++++++++++++++++------------------- ginka/vae_rnn/encoder.py | 44 +++++++++++++++++++++++++----- ginka/vae_rnn/vae.py | 27 ++++++++++++++++++- 4 files changed, 96 insertions(+), 34 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index df72c47..f0c6969 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -124,6 +124,7 @@ def train(): loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05) loss.backward() + torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=2.0) optimizer_ginka.step() loss_total += loss.detach() reco_loss_total += reco_loss.detach() diff --git a/ginka/vae_rnn/decoder.py b/ginka/vae_rnn/decoder.py index 05452dc..6a9c679 100644 --- a/ginka/vae_rnn/decoder.py +++ b/ginka/vae_rnn/decoder.py @@ -5,7 +5,7 @@ import torch.nn as nn import torch.nn.functional as F from ..utils import print_memory -class GinkaMapPatch(nn.Module): +class DecoderMapPatch(nn.Module): def __init__(self, tile_classes=32, width=13, height=13): super().__init__() @@ -64,18 +64,16 @@ class GinkaMapPatch(nn.Module): feat = self.fc(feat) return feat -class GinkaTileEmbedding(nn.Module): +class DecoderTileEmbedding(nn.Module): def __init__(self, tile_classes=32, embed_dim=256): super().__init__() - # 图块编码,上一次画的图块 - self.embedding = nn.Embedding(tile_classes, embed_dim) def forward(self, tile: torch.Tensor): return self.embedding(tile) -class GinkaPosEmbedding(nn.Module): +class DecoderPosEmbedding(nn.Module): def __init__(self, width=13, height=13, embed_dim=256): super().__init__() @@ -86,17 +84,13 @@ class GinkaPosEmbedding(nn.Module): self.row_embedding = nn.Embedding(height, embed_dim) self.col_embedding = nn.Embedding(width, embed_dim) - self.fusion = nn.Linear(embed_dim * 2, embed_dim) def forward(self, x: torch.Tensor, y: torch.Tensor): row = self.row_embedding(y) col = self.col_embedding(x) - embed = torch.cat([row, col], dim=2) - fused = self.fusion(embed) - - return fused + return row, col -class GinkaInputFusion(nn.Module): +class DecoderInputFusion(nn.Module): def __init__(self, d_model=256): super().__init__() @@ -107,24 +101,31 @@ class GinkaInputFusion(nn.Module): d_model=d_model, nhead=2, dim_feedforward=d_model, batch_first=True, dropout=0.2 ), - num_layers=3 + num_layers=2 ) + self.norm = nn.LayerNorm(d_model) + self.fusion = nn.Linear(d_model * 2, d_model) def forward( self, tile_embed: torch.Tensor, cond_vec: torch.Tensor, - pos_embed: torch.Tensor, patch_vec: torch.Tensor + col_embed: torch.Tensor, row_embed: torch.Tensor, patch_vec: torch.Tensor ): """ tile_embed: [B, 256] cond_vec: [B, 256] - pos_embed: [B, 256] + col_embed: [B, 256] + row_embed: [B, 256] patch_vec: [B, 256] """ - vec = torch.stack([tile_embed, cond_vec, pos_embed, patch_vec], dim=1) - feat = self.transformer(vec) - return feat[:, 0] + vec = torch.stack([tile_embed, cond_vec, col_embed, row_embed, patch_vec], dim=1) + feat = self.norm(self.transformer(vec)) + mean = torch.mean(feat, dim=1) + max = torch.max(feat, dim=1).values + hidden = torch.cat([mean, max], dim=1) + fused = self.fusion(hidden) + return fused -class GinkaRNN(nn.Module): +class DecoderRNN(nn.Module): def __init__(self, tile_classes=32, input_dim=256, hidden_dim=512): super().__init__() @@ -162,13 +163,16 @@ class VAEDecoder(nn.Module): # 模型结构 self.map_vec_fc = nn.Sequential( - nn.Linear(map_vec_dim, 256) + nn.Linear(map_vec_dim, 128), + nn.LayerNorm(128), + nn.GELU(), + nn.Linear(128, 256) ) - self.tile_embedding = GinkaTileEmbedding(tile_classes=self.tile_classes) - self.pos_embedding = GinkaPosEmbedding() - self.map_patch = GinkaMapPatch(tile_classes=self.tile_classes) - self.feat_fusion = GinkaInputFusion() - self.rnn = GinkaRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) + self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes) + self.pos_embedding = DecoderPosEmbedding() + self.map_patch = DecoderMapPatch(tile_classes=self.tile_classes) + self.feat_fusion = DecoderInputFusion() + self.rnn = DecoderRNN(tile_classes=self.tile_classes, hidden_dim=self.rnn_hidden) self.col_list = [] self.row_list = [] @@ -181,7 +185,7 @@ class VAEDecoder(nn.Module): """ map_vec: [B, vec_dim] target_map: [B, H, W] - use_self: 是否使用自己生成的上一步结果执行下一步 + use_self_probility: 使用自己生成的上一步结果执行下一步的概率 """ B, C = map_vec.shape @@ -194,7 +198,7 @@ class VAEDecoder(nn.Module): col_list = torch.IntTensor(self.col_list).to(self.device).expand(B, -1) row_list = torch.IntTensor(self.row_list).to(self.device).expand(B, -1) - pos_embed = self.pos_embedding(col_list, row_list) + col_embed, row_embed = self.pos_embedding(col_list, row_list) map_vec = self.map_vec_fc(map_vec) @@ -206,7 +210,7 @@ class VAEDecoder(nn.Module): use_self = random.random() < use_self_probility map_patch = self.map_patch(map if use_self else target_map, x, y) # 编码特征融合 - feat = self.feat_fusion(tile_embed, map_vec, pos_embed[:, idx], map_patch) + feat = self.feat_fusion(tile_embed, map_vec, col_embed[:, idx], row_embed[:, idx], map_patch) # RNN 输出 logits, h = self.rnn(feat, hidden) # 处理输出 diff --git a/ginka/vae_rnn/encoder.py b/ginka/vae_rnn/encoder.py index 57c907c..624976b 100644 --- a/ginka/vae_rnn/encoder.py +++ b/ginka/vae_rnn/encoder.py @@ -7,9 +7,21 @@ from ..utils import print_memory class EncoderEmbedding(nn.Module): def __init__(self, tile_classes=32, width=13, height=13, hidden_dim=128, output_dim=256): super().__init__() - self.tile_embedding = nn.Embedding(tile_classes, hidden_dim) - self.col_embedding = nn.Embedding(width, hidden_dim) - self.row_embedding = nn.Embedding(height, hidden_dim) + self.tile_embedding = nn.Sequential( + nn.Embedding(tile_classes, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU() + ) + self.col_embedding = nn.Sequential( + nn.Embedding(width, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU() + ) + self.row_embedding = nn.Sequential( + nn.Embedding(height, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.GELU() + ) self.fusion = nn.Linear(hidden_dim * 3, output_dim) def forward(self, tile, x, y): @@ -43,6 +55,25 @@ class EncoderGRU(nn.Module): hidden = self.drop(self.gru(feat, hidden)) logits = self.fc(hidden) return logits, hidden + +class EncoderFusion(nn.Module): + def __init__(self, d_model=256): + super().__init__() + + self.transformer = nn.TransformerEncoder( + nn.TransformerEncoderLayer( + d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True + ), + num_layers=2 + ) + self.norm = nn.LayerNorm(d_model) + + def forward(self, logits): + x = self.norm(self.transformer(logits)) + h_mean = torch.mean(x, dim=1) + h_max = torch.max(x, dim=1).values + h = torch.cat([h_mean, h_max], dim=1) + return h class VAEEncoder(nn.Module): def __init__(self, device, tile_classes=32, latent_dim=32, width=13, height=13): @@ -54,6 +85,7 @@ class VAEEncoder(nn.Module): self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256) self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim) + self.fusion = EncoderFusion(256) self.fc_mu = nn.Linear(512, latent_dim) self.fc_logvar = nn.Linear(512, latent_dim) @@ -79,9 +111,8 @@ class VAEEncoder(nn.Module): logits, h = self.rnn(embed[:, idx], hidden) hidden = h output[:, idx] = logits - h_mean = torch.mean(output, dim=1) - h_max = torch.max(output, dim=1).values - h = torch.cat([h_mean, h_max], dim=1) + + h = self.fusion(output) mu = self.fc_mu(h) logvar = self.fc_logvar(h) return mu, logvar @@ -107,4 +138,5 @@ if __name__ == "__main__": print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}") print(f"Embedding parameters: {sum(p.numel() for p in model.embedding.parameters())}") print(f"RNN parameters: {sum(p.numel() for p in model.rnn.parameters())}") + print(f"Fusion parameters: {sum(p.numel() for p in model.fusion.parameters())}") print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/ginka/vae_rnn/vae.py b/ginka/vae_rnn/vae.py index e183f6c..5ee9531 100644 --- a/ginka/vae_rnn/vae.py +++ b/ginka/vae_rnn/vae.py @@ -1,8 +1,10 @@ +import time import torch import torch.nn as nn import torch.nn.functional as F from .encoder import VAEEncoder from .decoder import VAEDecoder +from ..utils import print_memory class GinkaVAE(nn.Module): def __init__(self, device, tile_classes=32, latent_dim=32): @@ -19,4 +21,27 @@ class GinkaVAE(nn.Module): mu, logvar = self.encoder(target_map) z = self.reparameterize(mu, logvar) logits = self.decoder(z, target_map, use_self_probility) - return logits, mu, logvar \ No newline at end of file + return logits, mu, logvar + +if __name__ == "__main__": + device = torch.device("cpu") + + input = torch.randint(0, 32, [1, 13, 13]).to(device) + + # 初始化模型 + model = GinkaVAE(device).to(device) + + print_memory("初始化后") + + # 前向传播 + start = time.perf_counter() + logits, mu, logvar = model(input) + end = time.perf_counter() + + print_memory("前向传播后") + + print(f"推理耗时: {end - start}") + print(f"输出形状: logits= {logits.shape}, mu={mu.shape}, logvar={logvar.shape}") + print(f"Encoder parameters: {sum(p.numel() for p in model.encoder.parameters())}") + print(f"Decoder parameters: {sum(p.numel() for p in model.decoder.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")