fix: 报错

This commit is contained in:
unanmed 2026-03-10 18:24:40 +08:00
parent 973434553a
commit a119cbb155
2 changed files with 11 additions and 10 deletions

View File

@@ -45,12 +45,12 @@ from shared.image import matrix_to_image_cv
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token
BATCH_SIZE = 8
LATENT_DIM = 48
KL_BETA = 0.1
BATCH_SIZE = 128
LATENT_DIM = 32
KL_BETA = 0.01
SELF_GATE = 0.5
GATE_EPOCH = 5
VAL_BATCH_DIVIDER = 8
VAL_BATCH_DIVIDER = 128
PROB_STEP = 0.05
NUM_CLASSES = 16
@@ -89,7 +89,7 @@ def train():
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-2, betas=(0.9, 0.95))
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
scheduler_ginka = VAEScheduler(
optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=1e-4, min_lr=1e-6
@@ -179,7 +179,7 @@ def train():
scheduler_ginka.step(avg_loss, self_prob)
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % 1 == 0:
if (epoch + 1) % args.checkpoint == 0:
# 保存检查点
torch.save({
"model_state": vae.state_dict(),
@@ -208,7 +208,7 @@ def train():
val_reco_loss_total += reco_loss.detach()
val_kl_loss_total += kl_loss.detach()
fake_map = torch.argmax(fake_logits, dim=2).view(B, H, W).cpu().numpy()
fake_map = torch.argmax(fake_logits, dim=2)[:,0:169].view(B, H, W).cpu().numpy()
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
real_map = target_map.cpu().numpy()
real_img = matrix_to_image_cv(real_map[0], tile_dict)
@@ -223,7 +223,7 @@ def train():
vae.autoregressive()
fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
fake_map = fake_logits.view(-1, 13, 13).cpu().numpy()
fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy()
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)
@@ -244,7 +244,7 @@ def train():
for t in torch.linspace(0, 1, 8):
z = z1 * (1 - t / 8) + z2 * t / 8
fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
fake_map = fake_logits.view(-1, 13, 13).cpu().numpy()
fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy()
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]])

View File

@@ -7,7 +7,8 @@ class VAELoss:
def vae_loss(self, logits, target, mu, logvar, beta=0.1):
# target: [B, 169]
end_token = torch.tensor([15], dtype=torch.long).to(logits.device)
B, L = target.shape
end_token = torch.tensor([15], dtype=torch.long).to(logits.device).repeat(B, 1)
target = torch.cat([target, end_token], dim=1)
target = F.one_hot(target, num_classes=self.num_classes).float()
recon_loss = F.cross_entropy(logits, target)