mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 报错
This commit is contained in:
parent
973434553a
commit
a119cbb155
@ -45,12 +45,12 @@ from shared.image import matrix_to_image_cv
|
||||
# 0. 空地, 1. 墙壁, 2. 门, 3. 钥匙, 4. 红宝石, 5. 蓝宝石, 6. 绿宝石, 7. 血瓶
|
||||
# 8. 道具, 9. 怪物, 10. 入口, 14. 起始 token, 15. 终止 token
|
||||
|
||||
BATCH_SIZE = 8
|
||||
LATENT_DIM = 48
|
||||
KL_BETA = 0.1
|
||||
BATCH_SIZE = 128
|
||||
LATENT_DIM = 32
|
||||
KL_BETA = 0.01
|
||||
SELF_GATE = 0.5
|
||||
GATE_EPOCH = 5
|
||||
VAL_BATCH_DIVIDER = 8
|
||||
VAL_BATCH_DIVIDER = 128
|
||||
PROB_STEP = 0.05
|
||||
NUM_CLASSES = 16
|
||||
|
||||
@ -89,7 +89,7 @@ def train():
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // VAL_BATCH_DIVIDER, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-2, betas=(0.9, 0.95))
|
||||
# 自定义调度器允许在 self_prob 提高时重置调度器信息并提高学习率以适应学习
|
||||
scheduler_ginka = VAEScheduler(
|
||||
optimizer_ginka, factor=0.9, increase_factor=2, patience=10, max_lr=1e-4, min_lr=1e-6
|
||||
@ -179,7 +179,7 @@ def train():
|
||||
scheduler_ginka.step(avg_loss, self_prob)
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % 1 == 0:
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
# 保存检查点
|
||||
torch.save({
|
||||
"model_state": vae.state_dict(),
|
||||
@ -208,7 +208,7 @@ def train():
|
||||
val_reco_loss_total += reco_loss.detach()
|
||||
val_kl_loss_total += kl_loss.detach()
|
||||
|
||||
fake_map = torch.argmax(fake_logits, dim=2).view(B, H, W).cpu().numpy()
|
||||
fake_map = torch.argmax(fake_logits, dim=2)[:,0:169].view(B, H, W).cpu().numpy()
|
||||
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
||||
real_map = target_map.cpu().numpy()
|
||||
real_img = matrix_to_image_cv(real_map[0], tile_dict)
|
||||
@ -223,7 +223,7 @@ def train():
|
||||
|
||||
vae.autoregressive()
|
||||
fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
|
||||
fake_map = fake_logits.view(-1, 13, 13).cpu().numpy()
|
||||
fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy()
|
||||
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
||||
|
||||
cv2.imwrite(f"result/ginka_vae_img/{i}_rand.png", fake_img)
|
||||
@ -244,7 +244,7 @@ def train():
|
||||
for t in torch.linspace(0, 1, 8):
|
||||
z = z1 * (1 - t / 8) + z2 * t / 8
|
||||
fake_logits = vae.decoder(z, torch.zeros(1, 169).to(device))
|
||||
fake_map = fake_logits.view(-1, 13, 13).cpu().numpy()
|
||||
fake_map = fake_logits[:,0:169].view(-1, 13, 13).cpu().numpy()
|
||||
fake_img = matrix_to_image_cv(fake_map[0], tile_dict)
|
||||
img = np.block([[real_img1], [vline], [fake_img], [vline], [real_img2]])
|
||||
|
||||
|
||||
@ -7,7 +7,8 @@ class VAELoss:
|
||||
|
||||
def vae_loss(self, logits, target, mu, logvar, beta=0.1):
|
||||
# target: [B, 169]
|
||||
end_token = torch.tensor([15], dtype=torch.long).to(logits.device)
|
||||
B, L = target.shape
|
||||
end_token = torch.tensor([15], dtype=torch.long).to(logits.device).repeat(B, 1)
|
||||
target = torch.cat([target, end_token], dim=1)
|
||||
target = F.one_hot(target, num_classes=self.num_classes).float()
|
||||
recon_loss = F.cross_entropy(logits, target)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user