diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index e8eaf53..ed921eb 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -161,7 +161,7 @@ def train(): val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) - fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epoch)) + fake_logits, fake_map = ginka_rnn(val_cond, target_map, 1 - gt_prob(epoch, args.epochs)) val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()