chore: 调整参数

This commit is contained in:
unanmed 2026-02-03 13:25:41 +08:00
parent 14c46b0ceb
commit 10166fa073
2 changed files with 11 additions and 9 deletions

View File

@ -53,6 +53,8 @@ from shared.image import matrix_to_image_cv
# 29. 入口,不区分楼梯和箭头
BATCH_SIZE = 128
LATENT_DIM = 48
KL_BETA = 0.01
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@ -63,7 +65,7 @@ disable_tqdm = not sys.stdout.isatty()
def gt_prob(epoch: int, max_epoch: int) -> float:
progress = epoch / max_epoch
return max(2 * progress - 1, 0)
return 1
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
@ -82,14 +84,14 @@ def train():
args = parse_arguments()
vae = GinkaVAE(device).to(device)
vae = GinkaVAE(device, latent_dim=LATENT_DIM).to(device)
dataset = GinkaRNNDataset(args.train, device)
dataset_val = GinkaRNNDataset(args.validate, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 4, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=3e-4, weight_decay=1e-4)
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
criterion = VAELoss()
@ -121,10 +123,10 @@ def train():
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
loss.backward()
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=2.0)
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
optimizer_ginka.step()
loss_total += loss.detach()
reco_loss_total += reco_loss.detach()
@ -163,7 +165,7 @@ def train():
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, 0.05)
loss, reco_loss, kl_loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
val_loss_total += loss.detach()
reco_loss_total += reco_loss.detach()
kl_loss_total += kl_loss.detach()
@ -179,7 +181,7 @@ def train():
# 随机采样
for i in range(0, 8):
z = torch.randn(1, 32).to(device)
z = torch.randn(1, LATENT_DIM).to(device)
fake_logits = vae.decoder(z, torch.zeros(1, 13, 13).to(device), 1)
fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()

View File

@ -10,7 +10,7 @@ class GinkaVAE(nn.Module):
def __init__(self, device, tile_classes=32, latent_dim=32):
super().__init__()
self.encoder = VAEEncoder(device, tile_classes, latent_dim)
self.decoder = VAEDecoder(device)
self.decoder = VAEDecoder(device, map_vec_dim=latent_dim)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5 * logvar)