From d0decfc63a52fb5d001d92d75333b910caf38ebc Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 17 Dec 2025 13:12:44 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20device=20=E4=B8=8D=E6=AD=A3=E7=A1=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/loss.py | 2 +- ginka/generator/rnn.py | 2 +- ginka/train_rnn.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 60a6347..66a68ee 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -417,4 +417,4 @@ class RNNGinkaLoss: target: [B, H, W] """ target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2) - return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight) + return F.cross_entropy(fake, target, label_smoothing=0.1) diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 654cae2..7d430d9 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -72,7 +72,7 @@ class GinkaMapPatch(nn.Module): mask[:, 4, 2] = 0 mask[:, 4, 3] = 0 mask[:, 4, 4] = 0 - masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]) + masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]).to(map.device) masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float() masked_result[:, 32] = mask diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index 7970009..d450e9b 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -81,7 +81,7 @@ def train(): dataset = GinkaRNNDataset(args.train, device) dataset_val = GinkaRNNDataset(args.validate, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True) - dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) + dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 8) optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)