mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
chore: 调整参数
This commit is contained in:
parent
14c46b0ceb
commit
10166fa073
@ -53,6 +53,8 @@ from shared.image import matrix_to_image_cv
|
||||
# 29. 入口,不区分楼梯和箭头
|
||||
|
||||
BATCH_SIZE = 128
|
||||
LATENT_DIM = 48
|
||||
KL_BETA = 0.01
|
||||
|
||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
@ -63,7 +65,7 @@ disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
def gt_prob(epoch: int, max_epoch: int) -> float:
|
||||
progress = epoch / max_epoch
|
||||
return max(2 * progress - 1, 0)
|
||||
return 1
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
@ -82,14 +84,14 @@ def train():
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
vae = GinkaVAE(device).to(device)
|
||||
vae = GinkaVAE(device, latent_dim=LATENT_DIM).to(device)
|
||||
|
||||
dataset = GinkaRNNDataset(args.train, device)
|
||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
|
||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||
|
||||
criterion = VAELoss()
|
||||
@ -121,10 +123,10 @@ def train():
|
||||
|
||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
|
||||
|
||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)
|
||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=2.0)
|
||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
|
||||
optimizer_ginka.step()
|
||||
loss_total += loss.detach()
|
||||
reco_loss_total += reco_loss.detach()
|
||||
@ -163,7 +165,7 @@ def train():
|
||||
|
||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
|
||||
|
||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)
|
||||
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
||||
val_loss_total += loss.detach()
|
||||
reco_loss_total += reco_loss.detach()
|
||||
kl_loss_total += kl_loss.detach()
|
||||
@ -179,7 +181,7 @@ def train():
|
||||
|
||||
# 随机采样
|
||||
for i in range(0, 8):
|
||||
z = torch.randn(1, 32).to(device)
|
||||
z = torch.randn(1, LATENT_DIM).to(device)
|
||||
|
||||
fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
|
||||
fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
|
||||
|
||||
@ -10,7 +10,7 @@ class GinkaVAE(nn.Module):
|
||||
def __init__(self, device, tile_classes=32, latent_dim=32):
|
||||
super().__init__()
|
||||
self.encoder = VAEEncoder(device, tile_classes, latent_dim)
|
||||
self.decoder = VAEDecoder(device)
|
||||
self.decoder = VAEDecoder(device, map_vec_dim=latent_dim)
|
||||
|
||||
def reparameterize(self, mu, logvar):
|
||||
std = torch.exp(0.5 * logvar)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user