diff --git a/ginka/train_vae.py b/ginka/train_vae.py index 52da79b..6a92426 100644 --- a/ginka/train_vae.py +++ b/ginka/train_vae.py @@ -53,6 +53,8 @@ from shared.image import matrix_to_image_cv # 29. 入口,不区分楼梯和箭头 BATCH_SIZE = 128 +LATENT_DIM = 48 +KL_BETA = 0.01 device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) @@ -63,7 +65,7 @@ disable_tqdm = not sys.stdout.isatty() def gt_prob(epoch: int, max_epoch: int) -> float: progress = epoch / max_epoch - return max(2 * progress - 1, 0) + return 1 def parse_arguments(): parser = argparse.ArgumentParser(description="training codes") @@ -82,14 +84,14 @@ def train(): args = parse_arguments() - vae = GinkaVAE(device).to(device) + vae = GinkaVAE(device, latent_dim=LATENT_DIM).to(device) dataset = GinkaRNNDataset(args.train, device) dataset_val = GinkaRNNDataset(args.validate, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True) - optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4) + optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) criterion = VAELoss() @@ -121,10 +123,10 @@ def train(): fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs)) - loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05) + loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) loss.backward() - torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=2.0) + torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0) optimizer_ginka.step() loss_total += loss.detach() reco_loss_total += reco_loss.detach() @@ -163,7 +165,7 @@ def train(): fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs)) - loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05) + loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA) val_loss_total += loss.detach() reco_loss_total += reco_loss.detach() kl_loss_total += kl_loss.detach() @@ -179,7 +181,7 @@ def train(): # 随机采样 for i in range(0, 8): - z = torch.randn(1, 32).to(device) + z = torch.randn(1, LATENT_DIM).to(device) fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1) fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy() diff --git a/ginka/vae_rnn/vae.py b/ginka/vae_rnn/vae.py index 5ee9531..d6550b9 100644 --- a/ginka/vae_rnn/vae.py +++ b/ginka/vae_rnn/vae.py @@ -10,7 +10,7 @@ class GinkaVAE(nn.Module): def __init__(self, device, tile_classes=32, latent_dim=32): super().__init__() self.encoder = VAEEncoder(device, tile_classes, latent_dim) - self.decoder = VAEDecoder(device) + self.decoder = VAEDecoder(device, map_vec_dim=latent_dim) def reparameterize(self, mu, logvar): std = torch.exp(0.5 * logvar)