From a119cbb155f51b0cf359d652360f6c6ede2c976f Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 10 Mar 2026 18:24:40 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_transformer_vae.py | 18 +++++++++--------- ginka/vae_rnn/loss.py | 3 ++- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/ginka/train_transformer_vae.py b/ginka/train_transformer_vae.py index 84da62f..820cbfa 100644 --- a/ginka/train_transformer_vae.py +++ b/ginka/train_transformer_vae.py @@ -45,12 +45,12 @@ from shared.image import matrix_to_image_cv # 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶 # 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token -BATCH_SIZE = 8 -LATENT_DIM = 48 -KL_BETA = 0.1 +BATCH_SIZE = 128 +LATENT_DIM = 32 +KL_BETA = 0.01 SELF_GATE = 0.5 GATE_EPOCH = 5 -VAL_BATCH_DIVIDER = 8 +VAL_BATCH_DIVIDER = 128 PROB_STEP = 0.05 NUM_CLASSES = 16 @@ -89,7 +89,7 @@ def train(): dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True) - optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4) + optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-2, betas=(0.9, 0.95)) # 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习 scheduler_ginka = VAEScheduler( optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=1e-4, min_lr=1e-6 @@ -179,7 +179,7 @@ def train(): scheduler_ginka.step(avg_loss, self_prob) # 每若干轮输出一次图片,并保存检查点 - if (epoch + 1) % 1 == 0: + if (epoch + 1) % args.checkpoint == 0: # 保存检查点 torch.save({ "model_state": vae.state_dict(), @@ -208,7 +208,7 @@ def train(): val_reco_loss_total += reco_loss.detach() val_kl_loss_total += kl_loss.detach() - fake_map = torch.argmax(fake_logits, dim=2).view(B, H, W).cpu().numpy() + fake_map = torch.argmax(fake_logits, dim=2)[:,0:169].view(B, H, W).cpu().numpy() fake_img = matrix_to_image_cv(fake_map[0], tile_dict) real_map = target_map.cpu().numpy() real_img = matrix_to_image_cv(real_map[0], tile_dict) @@ -223,7 +223,7 @@ def train(): vae.autoregressive() fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) - fake_map = fake_logits.view(-1, 13, 13).cpu().numpy() + fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy() fake_img = matrix_to_image_cv(fake_map[0], tile_dict) cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img) @@ -244,7 +244,7 @@ def train(): for t in torch.linspace(0, 1, 8): z = z1 * (1 - t / 8) + z2 * t / 8 fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device)) - fake_map = fake_logits.view(-1, 13, 13).cpu().numpy() + fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy() fake_img = matrix_to_image_cv(fake_map[0], tile_dict) img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]]) diff --git a/ginka/vae_rnn/loss.py b/ginka/vae_rnn/loss.py index 4ebebee..381f340 100644 --- a/ginka/vae_rnn/loss.py +++ b/ginka/vae_rnn/loss.py @@ -7,7 +7,8 @@ class VAELoss: def vae_loss(self, logits, target, mu, logvar, beta=0.1): # target: [B, 169] - end_token = torch.tensor([15], dtype=torch.long).to(logits.device) + B, L = target.shape + end_token = torch.tensor([15], dtype=torch.long).to(logits.device).repeat(B, 1) target = torch.cat([target, end_token], dim=1) target = F.one_hot(target, num_classes=self.num_classes).float() recon_loss = F.cross_entropy(logits, target)