From 5586ea1039a87ddba7f2e74f5f0d1f24dc9cc301 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Sat, 14 Jun 2025 15:06:09 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=BE=AE=E8=B0=83=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/critic/model.py | 10 ++++----- ginka/dataset.py | 6 ++--- ginka/generator/loss.py | 16 ++++++------- ginka/train_wgan.py | 50 ++++++++++++++++------------------------- 4 files changed, 35 insertions(+), 47 deletions(-) diff --git a/ginka/critic/model.py b/ginka/critic/model.py index 29620ee..04c2cf2 100644 --- a/ginka/critic/model.py +++ b/ginka/critic/model.py @@ -230,8 +230,8 @@ class MinamoModel2(nn.Module): super().__init__() self.cond = ConditionEncoder(64, 16, 256, 256) - self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13) - self.conv2 = ConvFusionModule(128, 256, 256, 13, 13) + self.conv1 = ConvFusionModule(tile_types, 256, 256, 13, 13) + self.conv2 = ConvFusionModule(256, 512, 256, 13, 13) self.conv3 = ConvFusionModule(256, 512, 256, 13, 13) self.head0 = MinamoHead2(256, 256) # 随机头的判别头 @@ -239,9 +239,9 @@ class MinamoModel2(nn.Module): self.head2 = MinamoHead2(256, 256) self.head3 = MinamoHead2(256, 256) - # self.inject1 = ConditionInjector(256, 128) + # self.inject1 = ConditionInjector(256, 256) # self.inject2 = ConditionInjector(256, 256) - # self.inject3 = ConditionInjector(256, 256) + self.inject3 = ConditionInjector(256, 256) def forward(self, x, stage, tag_cond, val_cond): B, D = tag_cond.shape @@ -252,7 +252,7 @@ class MinamoModel2(nn.Module): x = self.conv2(x) # x = self.inject2(x, cond) x = self.conv3(x) - # x = self.inject3(x, cond) + x = self.inject3(x, cond) if stage == 0: score = self.head0(x, cond) diff --git a/ginka/dataset.py b/ginka/dataset.py index c7f6833..f4f8812 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -92,7 +92,7 @@ def apply_curriculum_wall_mask( removed_maps = masked_maps.clone() area = H * W * mask_ratio - l = math.ceil(math.sqrt(area)) + l = math.floor(math.sqrt(area)) nx = random.randint(0, W - l) ny = random.randint(0, H - l) masked_maps[mask_classes, nx:nx+l, ny:ny+l] = 0 @@ -155,7 +155,7 @@ class GinkaWGANDataset(Dataset): def handle_stage3(self, target, tag_cond, val_cond): # 第三阶段,联合生成,输入随机蒙版 - removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) + removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) rand = torch.rand(32, 32, 32, device=target.device) @@ -164,7 +164,7 @@ class GinkaWGANDataset(Dataset): "rand": rand, "real0": removed1, "real1": removed1, - "masked1": masked1, + "masked1": removed1, "real2": removed2, "masked2": torch.zeros_like(target), "real3": removed3, diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 2d772ad..5755c92 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -50,9 +50,9 @@ DENSITY_WEIGHTS = [ DENSITY_STAGE = [ [], - [1, 2, 10], - [1, 2, 3, 4, 10], - list(range(0, 11)) + [1, 2], + [1, 2, 3, 4], + list(range(0, 10)) ] def get_not_allowed(classes: list[int], include_illegal=False): @@ -232,7 +232,7 @@ def immutable_penalty_loss( target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float() # 差异区域(模型试图改变的地方) - penalty = F.cross_entropy(input_mask, target_mask) + penalty = torch.clamp(F.cross_entropy(input_mask, target_mask) - 0.2, min=0) return penalty @@ -314,7 +314,7 @@ class WGANGinkaLoss: probs_fake = F.softmax(fake, dim=1) fake_scores = critic(probs_fake, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage]) + minamo_loss = -torch.mean(fake_scores) ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) constraint_loss = inner_constraint_loss(probs_fake) @@ -342,7 +342,7 @@ class WGANGinkaLoss: probs_fake = F.softmax(fake, dim=1) fake_scores = critic(probs_fake, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage]) + minamo_loss = -torch.mean(fake_scores) illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage]) constraint_loss = inner_constraint_loss(probs_fake) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) @@ -368,7 +368,7 @@ class WGANGinkaLoss: probs_fake = F.softmax(fake, dim=1) fake_scores = critic(probs_fake, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage]) + minamo_loss = -torch.mean(fake_scores) immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage]) constraint_loss = inner_constraint_loss(probs_fake) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) @@ -397,7 +397,7 @@ class WGANGinkaLoss: losses = [ head_scores, input_head_illegal_loss(probs) * 50, - -js_divergence(probs_a, probs_b, softmax=False) * 0.1 + -js_divergence(probs_a, probs_b, softmax=False) * 0.5 ] return sum(losses) diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index d5930a2..70d6bd7 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -108,7 +108,7 @@ def train(): g_steps = 1 # 训练阶段 train_stage = 1 - mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 + mask_ratio = 0.2 # 蒙版区域大小 stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 total_epoch = 0 @@ -123,8 +123,8 @@ def train(): optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9)) - # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) - # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs) + scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1) + scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1) criterion = WGANGinkaLoss() @@ -171,6 +171,7 @@ def train(): if args.tuning: train_stage = 1 curr_epoch = curr_epoch // 4 + first_curr = first_curr // 4 stage_epoch = 0 mask_ratio = 0.2 @@ -183,8 +184,6 @@ def train(): dataset_val.mask_ratio1 = mask_ratio dataset_val.mask_ratio2 = mask_ratio dataset_val.mask_ratio3 = mask_ratio - - low_loss_epochs = 0 for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm): loss_total_minamo = torch.Tensor([0]).to(device) @@ -265,7 +264,7 @@ def train(): loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond) loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0 - + if train_stage < 4: fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1) @@ -282,9 +281,10 @@ def train(): avg_dis = dis_total.item() / len(dataloader) / c_steps tqdm.write( f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + - f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + + f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " + - f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}" + f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " + + f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}" ) # 每若干轮输出一次图片,并保存检查点 @@ -364,46 +364,34 @@ def train(): # 训练流程控制 if train_stage >= 2: - if (epoch + 1) % 10 == 1: + # train_stage = 4 + if (epoch + 1) % 100 == 5: train_stage = 3 - elif (epoch + 1) % 10 == 3: + elif (epoch + 1) % 100 == 20: train_stage = 4 - elif (epoch + 1) % 10 == 0: + elif (epoch + 1) % 100 == 0: train_stage = 2 if train_stage == 1: if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \ (mask_ratio > 0.3 and stage_epoch >= curr_epoch): - if mask_ratio >= 0.9: - train_stage = 2 mask_ratio += 0.2 - mask_ratio = min(mask_ratio, 0.9) - low_loss_epochs = 0 + mask_ratio = min(mask_ratio, 0.8) + stage_epoch = 0 + if mask_ratio >= 0.8: + train_stage = 2 stage_epoch += 1 total_epoch += 1 - - # scheduler_ginka.step() - # scheduler_minamo.step() - - # if avg_dis < 0: - # g_steps = max(int(-avg_dis * 5), 1) - # else: - # g_steps = 1 - - # if avg_loss_ginka > 0 and epoch > 20 and not args.resume: - # g_steps += int(min(avg_loss_ginka * 5, 50)) - - # if avg_loss_minamo > 0: - # c_steps = int(min(5 + avg_loss_minamo * 5, 15)) - # else: - # c_steps = 5 dataset.train_stage = train_stage dataset_val.train_stage = train_stage dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio + + scheduler_ginka.step() + scheduler_minamo.step() print("Train ended.") torch.save({