From e66919d11ef1b5a25bc6e33f8bde231ffcbb6e66 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sun, 8 Mar 2026 16:05:07 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=BE=AE=E8=B0=83=20rnn-vae?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_vae.py | 14 ++++++++------ ginka/vae_rnn/decoder.py | 17 +++++++++++++---- ginka/vae_rnn/encoder.py | 29 ++++++++++++++++++++--------- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 238321f..ae539d8 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -54,10 +54,10 @@ from shared.image import matrix_to_image_cv # 29. 入口,不区分楼梯和箭头 BATCH_SIZE = 128 -LATENT_DIM = 48 -KL_BETA = 0.1 -SELF_GATE = 0.5 -GATE_EPOCH = 5 +LATENT_DIM = 64 +KL_BETA = 0.01 +SELF_GATE = 0.3 +GATE_EPOCH = 10 VAL_BATCH_DIVIDER = 128 PROB_STEP = 0.05 @@ -96,10 +96,10 @@ def train(): dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4) + optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4) # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 scheduler_ginka = VAEScheduler( - optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=2e-4, min_lr=1e-6 + optimizer_ginka, factor=0.9, increase_factor=1.5, patience=20, max_lr=3e-4, min_lr=1e-6 ) criterion = VAELoss() @@ -129,6 +129,7 @@ def train(): reco_loss_total = torch.Tensor([0]).to(device) kl_loss_total = torch.Tensor([0]).to(device) + vae.train() for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): target_map = batch["target_map"].to(device) @@ -182,6 +183,7 @@ def train(): # 每若干轮输出一次图片,并保存检查点 if (epoch + 1) % args.checkpoint == 0: + vae.eval() # 保存检查点 torch.save({ "model_state": vae.state_dict(), diff --git a/ginka/vae_rnn/decoder.py b/ginka/vae_rnn/decoder.py index 2317333..09d355e 100644 --- a/ginka/vae_rnn/decoder.py +++ b/ginka/vae_rnn/decoder.py @@ -105,7 +105,13 @@ class DecoderInputFusion(nn.Module): ) self.norm = nn.LayerNorm(d_model) self.fusion = nn.Sequential( + nn.Linear(d_model * 2, d_model * 2), + nn.Dropout(0.2), + nn.LayerNorm(d_model * 2), + nn.GELU(), + nn.Linear(d_model * 2, d_model), + nn.Dropout(0.1), nn.LayerNorm(d_model), nn.GELU() ) @@ -138,6 +144,7 @@ class DecoderRNN(nn.Module): self.drop = nn.Dropout(0.2) self.fc = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), + nn.Dropout(0.1), nn.LayerNorm(hidden_dim), nn.GELU(), @@ -168,8 +175,10 @@ class VAEDecoder(nn.Module): # 模型结构 self.map_vec_fc = nn.Sequential( nn.Linear(map_vec_dim, 128), + nn.Dropout(0.1), nn.LayerNorm(128), nn.GELU(), + nn.Linear(128, 256) ) self.tile_embedding = DecoderTileEmbedding(tile_classes=self.tile_classes) @@ -227,22 +236,22 @@ class VAEDecoder(nn.Module): return output_logits.permute(0, 3, 1, 2) if __name__ == "__main__": - device = torch.device("cpu") + device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") input = torch.randint(0, 32, [1, 13, 13]).to(device) map_vec = torch.rand(1, 32).to(device) # 初始化模型 - model = VAEDecoder("cpu").to(device) + model = VAEDecoder(device).to(device) - print_memory("初始化后") + print_memory(device, "初始化后") # 前向传播 start = time.perf_counter() fake_logits = model(map_vec, input, 0) end = time.perf_counter() - print_memory("前向传播后") + print_memory(device, "前向传播后") print(f"推理耗时: {end - start}") print(f"输出形状: fake_logits={fake_logits.shape}") diff --git a/ginka/vae_rnn/encoder.py b/ginka/vae_rnn/encoder.py index 61332bf..de8717e 100644 --- a/ginka/vae_rnn/encoder.py +++ b/ginka/vae_rnn/encoder.py @@ -38,9 +38,10 @@ class EncoderGRU(nn.Module): # GRU self.gru = nn.GRUCell(input_dim, hidden_dim) - self.drop = nn.Dropout(0.1) + self.drop = nn.Dropout(0.2) self.fc = nn.Sequential( nn.Linear(hidden_dim, hidden_dim), + nn.Dropout(0.1), nn.LayerNorm(hidden_dim), nn.GELU(), @@ -62,13 +63,14 @@ class EncoderFusion(nn.Module): self.transformer = nn.TransformerEncoder( nn.TransformerEncoderLayer( - d_model=d_model, dim_feedforward=d_model, nhead=2, batch_first=True + d_model=d_model, dim_feedforward=d_model*2, nhead=2, batch_first=True ), - num_layers=2 + num_layers=3 ) self.norm = nn.LayerNorm(d_model) self.fc = nn.Sequential( nn.Linear(d_model * 2, d_model * 2), + nn.Dropout(0.1), nn.LayerNorm(d_model * 2), nn.GELU() ) @@ -91,11 +93,20 @@ class VAEEncoder(nn.Module): self.embedding = EncoderEmbedding(tile_classes, width, height, 128, 256) self.rnn = EncoderGRU(256, self.rnn_hidden, self.logits_dim) self.fusion = EncoderFusion(256) - self.fc = nn.Sequential( + self.fc_mu = nn.Sequential( + nn.Linear(512, 512), + nn.Dropout(0.1), + nn.LayerNorm(512), + nn.GELU(), + nn.Linear(512, latent_dim) + ) + self.fc_logvar = nn.Sequential( + nn.Linear(512, 512), + nn.Dropout(0.1), + nn.LayerNorm(512), + nn.GELU(), nn.Linear(512, latent_dim) ) - self.fc_mu = nn.Linear(512, latent_dim) - self.fc_logvar = nn.Linear(512, latent_dim) self.col_list = [] self.row_list = [] @@ -126,21 +137,21 @@ class VAEEncoder(nn.Module): return mu, logvar if __name__ == "__main__": - device = torch.device("cpu") + device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") input = torch.randint(0, 32, [1, 13, 13]).to(device) # 初始化模型 model = VAEEncoder(device).to(device) - print_memory("初始化后") + print_memory(device, "初始化后") # 前向传播 start = time.perf_counter() mu, logvar = model(input) end = time.perf_counter() - print_memory("前向传播后") + print_memory(device, "前向传播后") print(f"推理耗时: {end - start}") print(f"输出形状: mu={mu.shape}, logvar={logvar.shape}")