diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 363c7f8..60a6347 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -403,12 +403,12 @@ class WGANGinkaLoss: return sum(losses) class RNNGinkaLoss: - def __init__(self, num_classes): + def __init__(self, num_classes, device): self.num_classes = num_classes weight = torch.ones(self.num_classes) weight[0] = 0.3 weight[1] = 0.5 - self.weight = weight + self.weight = weight.to(device) pass def rnn_loss(self, fake, target): @@ -416,5 +416,5 @@ class RNNGinkaLoss: fake: [B, C, H, W] target: [B, H, W] """ - target = F.one_hot(target, num_classes=self.num_classes).float() + target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2) return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight) diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index a7db7b5..7970009 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -86,7 +86,7 @@ def train(): optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) - criterion = RNNGinkaLoss(32) + criterion = RNNGinkaLoss(32, device) # 用于生成图片 tile_dict = dict()