From a94b07bda85f47c858a073e199703b7ec5992b21 Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Sun, 20 Apr 2025 22:06:20 +0800
Subject: [PATCH] feat: improve the training workflow
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ginka/dataset.py     |  2 --
 ginka/model/input.py |  5 ++--
 ginka/model/loss.py  | 16 ++++++------
 ginka/model/model.py | 19 +++++++--------
 ginka/train_wgan.py  | 58 +++++++++++++++++++++++---------------------
 5 files changed, 51 insertions(+), 49 deletions(-)

diff --git a/ginka/dataset.py b/ginka/dataset.py
index 10b7d6e..f87f7a2 100644
--- a/ginka/dataset.py
+++ b/ginka/dataset.py
@@ -106,11 +106,9 @@ class GinkaWGANDataset(Dataset):
 
     def handle_stage3(self, target):
         # Stage 3: joint generation; the input is a random mask
-        rd = random.uniform(0, self.random_ratio)
         removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
         removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
         removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
-        masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
         return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
 
     def handle_stage4(self, target):
diff --git a/ginka/model/input.py b/ginka/model/input.py
index 554aec3..2281cc4 100644
--- a/ginka/model/input.py
+++ b/ginka/model/input.py
@@ -2,10 +2,10 @@ import torch
 import torch.nn as nn
 
 class RandomInputHead(nn.Module):
-    def __init__(self, in_size=(32, 32), out_size=(32, 32)):
+    def __init__(self):
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(1, 32, 3, padding=1, padding_mode='replicate'),
+            nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'),
             nn.InstanceNorm2d(32),
             nn.ELU(),
 
@@ -18,6 +18,7 @@ class RandomInputHead(nn.Module):
             nn.ELU(),
         )
         self.out_conv = nn.Sequential(
+            nn.AdaptiveMaxPool2d((13, 13)),
             nn.Conv2d(128, 32, 1),
         )
 
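Note on the ginka/model/input.py hunks above: RandomInputHead now consumes a 32-channel map instead of a single channel, and the new AdaptiveMaxPool2d((13, 13)) pools down to the map grid before the final 1x1 projection. A minimal shape sketch, with a single 32 -> 128 stand-in conv for the layers the hunks elide, so only the shapes are meaningful here:

    import torch
    import torch.nn as nn

    # Shape sketch only; the layers between the two hunks are not shown in this
    # patch, so one 32 -> 128 conv stands in for them.
    head_sketch = nn.Sequential(
        nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'),
        nn.InstanceNorm2d(32),
        nn.ELU(),
        nn.Conv2d(32, 128, 3, padding=1, padding_mode='replicate'),  # stand-in for elided layers
        nn.ELU(),
        nn.AdaptiveMaxPool2d((13, 13)),  # added by this patch
        nn.Conv2d(128, 32, 1),
    )

    x = torch.randn(2, 32, 32, 32)   # any 32-channel input; only the shapes matter
    print(head_sketch(x).shape)      # torch.Size([2, 32, 13, 13])

The (13, 13) output matches the map size that GinkaInput and GinkaOutput are constructed with in ginka/model/model.py below.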
diff --git a/ginka/model/loss.py b/ginka/model/loss.py
index b5ad2dd..c283218 100644
--- a/ginka/model/loss.py
+++ b/ginka/model/loss.py
@@ -419,18 +419,18 @@ class WGANGinkaLoss:
         fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
 
         minamo_loss = -torch.mean(fake_scores)
-        ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)
+        ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # the larger the mask, the smaller the cross-entropy weight
         immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
         constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
-        # fake_a, fake_b = fake.chunk(2, dim=0)
+        fake_a, fake_b = fake.chunk(2, dim=0)
 
         losses = [
             minamo_loss * self.weight[0],
-            ce_loss * self.weight[1], # the larger the mask, the smaller the cross-entropy weight
+            ce_loss * self.weight[1],
             immutable_loss * self.weight[2],
             constraint_loss * self.weight[3],
-            # -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
+            -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
         ]
 
         if stage == 1:
@@ -450,12 +450,12 @@
         minamo_loss = -torch.mean(fake_scores)
         constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
-        # fake_a, fake_b = fake.chunk(2, dim=0)
+        fake_a, fake_b = fake.chunk(2, dim=0)
 
         losses = [
             minamo_loss * self.weight[0],
             constraint_loss * self.weight[3],
-            # -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
+            -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
         ]
 
         if stage == 1:
@@ -474,13 +474,13 @@
         immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
         constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
-        # fake_a, fake_b = fake.chunk(2, dim=0)
+        fake_a, fake_b = fake.chunk(2, dim=0)
 
         losses = [
             minamo_loss * self.weight[0],
             immutable_loss * self.weight[2],
             constraint_loss * self.weight[3],
-            # -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
+            -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
        ]
 
         if stage == 1:
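Note on the ginka/model/loss.py hunks above: the batch is split into two halves with fake.chunk(2, dim=0) and -js_divergence(fake_a, fake_b, softmax=True) is added to the generator losses, so maximizing the Jensen-Shannon divergence between the two halves rewards sample diversity; the pairing implicitly assumes an even batch size. js_divergence itself is defined elsewhere in loss.py and is not part of this patch; a minimal sketch of what such a helper typically computes, assuming [N, C, H, W] logits with a per-cell categorical distribution over dim 1:

    import torch
    import torch.nn.functional as F

    def js_divergence_sketch(a: torch.Tensor, b: torch.Tensor, softmax: bool = True,
                             eps: float = 1e-8) -> torch.Tensor:
        """JS divergence between two [N, C, H, W] tensors, averaged over batch and
        spatial positions. Sketch only; the project's helper may differ."""
        if softmax:
            # softmax=True means the inputs are raw logits
            a = F.softmax(a, dim=1)
            b = F.softmax(b, dim=1)
        m = 0.5 * (a + b)
        kl_am = (a * ((a + eps).log() - (m + eps).log())).sum(dim=1)  # KL(a || m) per cell
        kl_bm = (b * ((b + eps).log() - (m + eps).log())).sum(dim=1)  # KL(b || m) per cell
        return 0.5 * (kl_am + kl_bm).mean()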
diff --git a/ginka/model/model.py b/ginka/model/model.py
index 260d2fe..02ee7fc 100644
--- a/ginka/model/model.py
+++ b/ginka/model/model.py
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .unet import GinkaUNet
 from .output import GinkaOutput
-from .input import GinkaInput
+from .input import GinkaInput, RandomInputHead
 
 def print_memory(tag=""):
     print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -13,21 +13,20 @@ class GinkaModel(nn.Module):
         """Ginka Model definition
         """
         super().__init__()
+        self.head = RandomInputHead()
         self.input = GinkaInput(32, 32, (13, 13), (32, 32))
         self.unet = GinkaUNet(32, base_ch, base_ch)
         self.output = GinkaOutput(base_ch, out_ch, (13, 13))
 
-    def forward(self, x, stage):
-        """
-        Args:
-            x: feature tensor of the reference map
-        Returns:
-            logits: output logits [BS, num_classes, H, W]
-        """
-        x = self.input(x)
+    def forward(self, x, stage, random=False):
+        if random:
+            x_in = F.softmax(self.head(x))
+        else:
+            x_in = x
+        x = self.input(x_in)
         x = self.unet(x)
         x = self.output(x, stage)
-        return x
+        return x, x_in
 
 # Check VRAM usage
 if __name__ == "__main__":
diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py
index edf603c..a6c2f6d 100644
--- a/ginka/train_wgan.py
+++ b/ginka/train_wgan.py
@@ -33,31 +33,32 @@ def parse_arguments():
     parser.add_argument("--epochs", type=int, default=100)
     parser.add_argument("--checkpoint", type=int, default=5)
     parser.add_argument("--load_optim", type=bool, default=True)
+    parser.add_argument("--curr_epoch", type=int, default=20) # minimum number of curriculum-learning epochs
     args = parser.parse_args()
     return args
 
 def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    fake1: torch.Tensor = gen(masked1, 1)
-    fake2: torch.Tensor = gen(masked2, 2)
-    fake3: torch.Tensor = gen(masked3, 3)
+    fake1, _ = gen(masked1, 1)
+    fake2, _ = gen(masked2, 2)
+    fake3, _ = gen(masked3, 3)
     if detach:
         return fake1.detach(), fake2.detach(), fake3.detach()
     else:
         return fake1, fake2, fake3
 
-def gen_total(gen, input, progress_detach=True, result_detach=False) -> torch.Tensor:
+def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
     if progress_detach:
-        fake1 = gen(input.detach(), 1)
-        fake2 = gen(fake1.detach(), 2)
-        fake3 = gen(fake2.detach(), 3)
+        fake1, x_in = gen(input.detach(), 1, random)
+        fake2, _ = gen(F.softmax(fake1.detach()), 2)
+        fake3, _ = gen(F.softmax(fake2.detach()), 3)
     else:
-        fake1 = gen(input, 1)
-        fake2 = gen(fake1, 2)
-        fake3 = gen(fake2, 3)
+        fake1, x_in = gen(input, 1, random)
+        fake2, _ = gen(F.softmax(fake1), 2)
+        fake3, _ = gen(F.softmax(fake2), 3)
     if result_detach:
-        return fake1.detach(), fake2.detach(), fake3.detach()
+        return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
     else:
-        return fake1, fake2, fake3
+        return fake1, fake2, fake3, x_in
 
 def train():
     print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@@ -169,14 +170,9 @@
                 if train_stage == 1 or train_stage == 2:
                     fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
 
-                elif train_stage == 3:
-                    fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
-
-                elif train_stage == 4:
-                    input = F.softmax(ginka_head(masked1), dim=1)
-                    fake1, fake2, fake3 = gen_total(ginka, input, True, True)
+                elif train_stage == 3 or train_stage == 4:
+                    fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4)
 
-
                 loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
                 loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
                 loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
@@ -214,9 +210,7 @@ def train():
                     loss_ce_total += loss_ce.detach()
 
                 elif train_stage == 3 or train_stage == 4:
-                    input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
-
-                    fake1, fake2, fake3 = gen_total(ginka, input, True, False)
+                    fake1, fake2, fake3, x_in = gen_total(ginka, input, True, False)
 
                     if train_stage == 3:
                         loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
@@ -225,6 +219,10 @@ def train():
                     loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
                     loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
 
+                    if train_stage == 4:
+                        loss_head = criterion.generator_input_head_loss(x_in)
+                        loss_head.backward()
+
                     loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
                     loss_g.backward()
                     optimizer_ginka.step()
@@ -246,24 +244,30 @@ def train():
             else:
                 low_loss_epochs = 0
 
-            if low_loss_epochs >= 3 and train_stage == 1:
+            # Training flow control
+
+            if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch:
                 if mask_ratio >= 0.9:
                     train_stage = 2
-                    stage_epoch = 0
                 mask_ratio += 0.2
                 mask_ratio = min(mask_ratio, 0.9)
                 low_loss_epochs = 0
+                stage_epoch = 0
 
-            if train_stage == 3 or train_stage == 2:
+            if (train_stage == 3 or train_stage == 2) and not last_stage:
                 if stage_epoch >= 25:
                     train_stage += 1
                     stage_epoch = 0
 
-            if train_stage >= 3:
+                    if train_stage == 4:
+                        last_stage = True
+
+            if train_stage >= 3 or last_stage:
                 # From stage 3 on, the cross-entropy loss should no longer take effect
                 mask_ratio = 1.0
 
             if last_stage:
+                mask_ratio = 1.0
                 if train_stage == 2 and stage_epoch % 5 == 0:
                     train_stage = 4
@@ -317,7 +321,7 @@ def train():
                 elif train_stage == 3 or train_stage == 4:
                     input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
 
-                    fake1, fake2, fake3 = gen_total(ginka, input, True, True)
+                    fake1, fake2, fake3, _ = gen_total(ginka, input, True, True)
 
                     fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
                     fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
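Pulled out of the ginka/train_wgan.py hunks above, the per-epoch stage schedule after this patch behaves roughly as follows. This is a readability sketch only: names and thresholds are copied from the diff, the last_stage branch (which also forces mask_ratio to 1.0 and periodically jumps back to stage 4) is only partially visible and therefore omitted, and code outside the shown hunks may add further conditions.

    def advance_stage(train_stage, stage_epoch, low_loss_epochs, mask_ratio,
                      last_stage, curr_epoch=20):
        """Readability sketch of the stage control in train(); not a drop-in helper."""
        # Stage 1: after 3 consecutive low-loss epochs and at least curr_epoch epochs
        # in the stage, widen the mask; once mask_ratio reaches 0.9, move to stage 2.
        if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= curr_epoch:
            if mask_ratio >= 0.9:
                train_stage = 2
            mask_ratio = min(mask_ratio + 0.2, 0.9)
            low_loss_epochs = 0
            stage_epoch = 0
        # Stages 2 and 3: advance after 25 epochs each; reaching stage 4 starts the final phase.
        if train_stage in (2, 3) and not last_stage:
            if stage_epoch >= 25:
                train_stage += 1
                stage_epoch = 0
                if train_stage == 4:
                    last_stage = True
        # From stage 3 on (and throughout the final phase) mask_ratio is forced to 1.0,
        # which zeroes the cross-entropy term's weight of (1 - mask_ratio).
        if train_stage >= 3 or last_stage:
            mask_ratio = 1.0
        return train_stage, stage_epoch, low_loss_epochs, mask_ratio, last_stage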