From 1962c7a71247cc3e0ca0348d9229e45113fb746b Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 15 Dec 2025 12:41:39 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=8D=9F=E5=A4=B1=E5=80=BC=E8=AE=A1?= =?UTF-8?q?=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/loss.py | 6 +++--- ginka/train_rnn.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 363c7f8..60a6347 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -403,12 +403,12 @@ class WGANGinkaLoss: return sum(losses) class RNNGinkaLoss: - def __init__(self, num_classes): + def __init__(self, num_classes, device): self.num_classes = num_classes weight = torch.ones(self.num_classes) weight[0] = 0.3 weight[1] = 0.5 - self.weight = weight + self.weight = weight.to(device) pass def rnn_loss(self, fake, target): @@ -416,5 +416,5 @@ class RNNGinkaLoss: fake: [B, C, H, W] target: [B, H, W] """ - target = F.one_hot(target, num_classes=self.num_classes).float() + target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2) return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight) diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index a7db7b5..7970009 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -86,7 +86,7 @@ def train(): optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) - criterion = RNNGinkaLoss(32) + criterion = RNNGinkaLoss(32, device) # 用于生成图片 tile_dict = dict()