diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index 933d143..8badc99 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -8,7 +8,7 @@ import cv2 from torch_geometric.loader import DataLoader from tqdm import tqdm from .common.cond import ConditionEncoder -from .generator.rnn import GinkaRNN +from .generator.rnn import GinkaRNNModel from .dataset import GinkaRNNDataset from .generator.loss import RNNGinkaLoss from shared.image import matrix_to_image_cv @@ -76,15 +76,14 @@ def train(): args = parse_arguments() - cond_inj = ConditionEncoder().to(device) - ginka_rnn = GinkaRNN().to(device) + ginka_rnn = GinkaRNNModel(device).to(device) dataset = GinkaRNNDataset(args.train, device) dataset_val = GinkaRNNDataset(args.validate, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) - optimizer_ginka = optim.Adam(list(ginka_rnn.parameters()) + list(cond_inj.parameters()), lr=1e-3, betas=(0.0, 0.9)) + optimizer_ginka = optim.Adam(ginka_rnn.parameters(), lr=1e-3) scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2) criterion = RNNGinkaLoss() @@ -112,16 +111,12 @@ def train(): iters = 0 for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - tag_cond = batch["tag_cond"].to(device) val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) - B, D = val_cond.shape - stage = torch.Tensor([0]).expand(B, 1).to(device) - cond_vec = cond_inj(tag_cond, val_cond, stage) - fake = ginka_rnn(target_map, cond_vec) + fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) - loss = criterion.rnn_loss(fake, target_map) + loss = criterion.rnn_loss(fake_logits, target_map) loss.backward() optimizer_ginka.step() @@ -159,23 +154,25 @@ def train(): with torch.no_grad(): idx = 0 for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): - tag_cond = batch["tag_cond"].to(device) val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) B, T = val_cond.shape - stage = torch.Tensor([0]).expand(B, 1).to(device) - cond_vec = cond_inj(tag_cond, val_cond, stage) - fake = ginka_rnn(target_map, cond_vec) + fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) - val_loss_total += criterion.rnn_loss(fake, target_map).detach() + val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach() - B, T, C = fake.shape - fake_map = torch.argmax(fake, dim=-1).reshape(B, 13, 13).cpu().numpy() + fake_map = fake_map.cpu().numpy() fake_img = matrix_to_image_cv(fake_map[0], tile_dict) cv2.imwrite(f"result/ginka_rnn_img/{idx}.png", fake_img) idx += 1 + + avg_loss_val = val_loss_total.item() / len(dataloader_val) + tqdm.write( + f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] E: {epoch} | " + + f"Loss: {avg_loss_val:.6f}" + ) print("Train ended.") torch.save({