mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
feat: 更改调度
This commit is contained in:
parent
bc73cf9cc3
commit
b01357d99e
@ -54,7 +54,6 @@ from shared.image import matrix_to_image_cv
|
|||||||
|
|
||||||
BATCH_SIZE = 128
|
BATCH_SIZE = 128
|
||||||
LATENT_DIM = 48
|
LATENT_DIM = 48
|
||||||
KL_BETA = 0.01
|
|
||||||
|
|
||||||
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
@ -63,15 +62,6 @@ os.makedirs("result/ginka_vae_img", exist_ok=True)
|
|||||||
|
|
||||||
disable_tqdm = not sys.stdout.isatty()
|
disable_tqdm = not sys.stdout.isatty()
|
||||||
|
|
||||||
def gt_prob(epoch: int, max_epoch: int) -> float:
|
|
||||||
progress = epoch / max_epoch
|
|
||||||
if progress < 0.2:
|
|
||||||
return 1
|
|
||||||
elif progress < 0.8:
|
|
||||||
return 1 - (progress - 0.2) / 0.6
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(description="training codes")
|
parser = argparse.ArgumentParser(description="training codes")
|
||||||
parser.add_argument("--resume", type=bool, default=False)
|
parser.add_argument("--resume", type=bool, default=False)
|
||||||
@ -96,10 +86,12 @@ def train():
|
|||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 64, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(vae.parameters(), lr=2e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
scheduler_ginka = optim.lr_scheduler.ReduceLROnPlateau(optimizer_ginka, factor=0.9, patience=40)
|
||||||
|
|
||||||
criterion = VAELoss()
|
criterion = VAELoss()
|
||||||
|
|
||||||
|
gt_prob = 1
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
tile_dict = dict()
|
tile_dict = dict()
|
||||||
@ -124,9 +116,9 @@ def train():
|
|||||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
fake_logits, mu, logvar = vae(target_map, 1 - gt_prob(epoch, args.epochs))
|
fake_logits, z = vae(target_map, 1 - gt_prob)
|
||||||
|
|
||||||
loss = criterion.vae_loss(fake_logits, target_map, mu, logvar, KL_BETA)
|
loss = criterion.vae_loss(fake_logits, target_map)
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
|
||||||
@ -140,7 +132,7 @@ def train():
|
|||||||
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
scheduler_ginka.step()
|
scheduler_ginka.step(avg_loss)
|
||||||
|
|
||||||
# 每若干轮输出一次图片,并保存检查点
|
# 每若干轮输出一次图片,并保存检查点
|
||||||
if (epoch + 1) % args.checkpoint == 0:
|
if (epoch + 1) % args.checkpoint == 0:
|
||||||
@ -151,7 +143,6 @@ def train():
|
|||||||
}, f"result/rnn/ginka-{epoch + 1}.pth")
|
}, f"result/rnn/ginka-{epoch + 1}.pth")
|
||||||
|
|
||||||
val_loss_total = torch.Tensor([0]).to(device)
|
val_loss_total = torch.Tensor([0]).to(device)
|
||||||
reco_loss_total = torch.Tensor([0]).to(device)
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
idx = 0
|
idx = 0
|
||||||
gap = 5
|
gap = 5
|
||||||
@ -161,9 +152,9 @@ def train():
|
|||||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
target_map = batch["target_map"].to(device)
|
target_map = batch["target_map"].to(device)
|
||||||
|
|
||||||
fake_logits, z = vae(target_map, 1 - gt_prob(epoch, args.epochs))
|
fake_logits, z = vae(target_map, 1 - gt_prob)
|
||||||
|
|
||||||
loss = criterion.vae_loss(fake_logits, target_map, z, KL_BETA)
|
loss = criterion.vae_loss(fake_logits, target_map)
|
||||||
val_loss_total += loss.detach()
|
val_loss_total += loss.detach()
|
||||||
|
|
||||||
fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
|
fake_map = torch.argmax(fake_logits, dim=1).cpu().numpy()
|
||||||
@ -212,6 +203,9 @@ def train():
|
|||||||
f"Loss: {avg_loss_val:.6f}"
|
f"Loss: {avg_loss_val:.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if avg_loss_val < 0.5 and gt_prob > 0:
|
||||||
|
gt_prob -= 0.1
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": vae.state_dict(),
|
"model_state": vae.state_dict(),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user