From 3946d83d6c421f0e27357bc0111ae8f4aec1425c Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Mon, 15 Dec 2025 12:32:14 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BA=A4=E5=8F=89=E7=86=B5=E6=8D=9F?= =?UTF-8?q?=E5=A4=B1=E8=AE=A1=E7=AE=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/generator/loss.py | 9 +++++---- ginka/generator/rnn.py | 2 +- ginka/train_rnn.py | 3 +-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index f4397b2..363c7f8 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -405,6 +405,10 @@ class WGANGinkaLoss: class RNNGinkaLoss: def __init__(self, num_classes): self.num_classes = num_classes + weight = torch.ones(self.num_classes) + weight[0] = 0.3 + weight[1] = 0.5 + self.weight = weight pass def rnn_loss(self, fake, target): @@ -412,8 +416,5 @@ class RNNGinkaLoss: fake: [B, C, H, W] target: [B, H, W] """ - weight = torch.ones(self.num_classes) - weight[0] = 0.3 - weight[1] = 0.5 target = F.one_hot(target, num_classes=self.num_classes).float() - return F.cross_entropy(fake, target, label_smoothing=0.1, weight=weight) + return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight) diff --git a/ginka/generator/rnn.py b/ginka/generator/rnn.py index 50c206b..6c5fd1e 100644 --- a/ginka/generator/rnn.py +++ b/ginka/generator/rnn.py @@ -205,7 +205,7 @@ class GinkaRNNModel(nn.Module): map[:, y, x] = tile_id[:] now_tile = tile_id if use_self else target_map[:, y, x].detach() - return output_logits, map + return output_logits.permute(0, 3, 1, 2), map def print_memory(device, tag=""): if torch.cuda.is_available(): diff --git a/ginka/train_rnn.py b/ginka/train_rnn.py index 895e077..a7db7b5 100644 --- a/ginka/train_rnn.py +++ b/ginka/train_rnn.py @@ -86,7 +86,7 @@ def train(): optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4) scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6) - criterion = RNNGinkaLoss() + criterion = RNNGinkaLoss(32) # 用于生成图片 tile_dict = dict() @@ -158,7 +158,6 @@ def train(): val_cond = batch["val_cond"].to(device) target_map = batch["target_map"].to(device) - B, T = val_cond.shape fake_logits, fake_map = ginka_rnn(val_cond, target_map, False) val_loss_total += criterion.rnn_loss(fake_logits, target_map).detach()