From ae531946943e818407721f3f978af4ca36e2379e Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 13 Dec 2025 19:14:48 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BF=AE=E6=94=B9=20batch=5Fsize?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/rnn.py | 2 +- ginka/train_rnn.py | 9 ++++----- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 69cf75c..1c1ff92 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -138,7 +138,7 @@ class GinkaRNN(nn.Module): """ hidden = self.gru(feat_fusion, hidden) logits = self.fc(hidden) - return logits, hidden + return F.sigmoid(logits), hidden class GinkaRNNModel(nn.Module): def __init__(self, device: torch.device, start_tile=31, width=13, height=13): diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index 45b0d62..8807c08 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -50,7 +50,7 @@ from shared.image import matrix_to_image_cv # 26-28. 三种等级的怪物 # 29. 入口,不区分楼梯和箭头 -BATCH_SIZE = 8 +BATCH_SIZE = 96 device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) @@ -114,8 +114,7 @@ def train(): val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) - with torch.autograd.set_detect_anomaly(True): - fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) + fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) loss = criterion.rnn_loss(fake_logits, target_map) @@ -125,7 +124,7 @@ def train(): iters += 1 - # if iters % 100 == 0: + # if iters % 50 == 0: # avg_loss_ginka = loss_total_ginka.item() / iters # tqdm.write( @@ -160,7 +159,7 @@ def train(): B, T = val_cond.shape fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) - + val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach() fake_map = fake_map.cpu().numpy()