diff --git a/ginka/dataset.py b/ginka/dataset.py index 0142f0b..866b2c7 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -31,8 +31,22 @@ def load_minamo_gan_data(data: list): res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True)) return res +def apply_curriculum_remove( + maps: torch.Tensor, + remove_classes: List[int], # 要移除的类别索引 +): + C, H, W = maps.shape + device = maps.device + removed_maps = maps.clone() + + remove_mask = removed_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0 + removed_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0 + removed_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地” + + return removed_maps.to(device) + def apply_curriculum_mask( - maps: torch.Tensor, # [B, C, H, W] + maps: torch.Tensor, # [C, H, W] mask_classes: List[int], # 要遮挡的类别索引 remove_classes: List[int], # 要移除的类别索引 mask_ratio: float # 遮挡比例 0~1 @@ -73,6 +87,42 @@ class GinkaWGANDataset(Dataset): def __len__(self): return len(self.data) + + def handle_stage1(self, target): + removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) + removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2) + removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3) + + return removed1, masked1, removed2, masked2, removed3, masked3 + + def handle_stage2(self, target): + removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) + # 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可 + removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1)) + removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1)) + + if self.random_ratio > 0: + rd = random.uniform(0, self.random_ratio) + masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd) + masked2 = random_smooth_onehot(masked2, min_main=1 - rd, max_main=1.0, 
epsilon=rd) + masked3 = random_smooth_onehot(masked3, min_main=1 - rd, max_main=1.0, epsilon=rd) + + return removed1, masked1, removed2, masked2, removed3, masked3 + + def handle_stage3(self, target): + rd = random.uniform(0, self.random_ratio) + removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) + removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) + removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) + masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd) + return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target) + + def handle_stage4(self, target): + input1 = torch.rand((32, 13, 13)) + removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) + removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) + removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) + return removed1, input1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target) def __getitem__(self, idx): item = self.data[idx] @@ -80,18 +130,16 @@ class GinkaWGANDataset(Dataset): target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] if self.train_stage == 1: - removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) - removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2) - removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3) + return self.handle_stage1(target) + elif self.train_stage == 2: - removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) - removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9)) - removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9)) + return self.handle_stage2(target) - if self.random_ratio > 0: - removed1 = 
random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio) - removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio) - removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio) + elif self.train_stage == 3: + return self.handle_stage3(target) - return removed1, masked1, removed2, masked2, removed3, masked3 + elif self.train_stage == 4: + return self.handle_stage4(target) + + raise RuntimeError(f"Invalid train stage: {self.train_stage}") \ No newline at end of file diff --git a/ginka/model/loss.py b/ginka/model/loss.py index ffce4cb..6ce7c1f 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -327,7 +327,7 @@ def immutable_penalty_loss( target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float() # 差异区域(模型试图改变的地方) - penalty = F.cross_entropy(input_mask, target_mask) + penalty = F.l1_loss(input_mask, target_mask) return penalty @@ -405,13 +405,13 @@ class WGANGinkaLoss: fake_scores, _, _ = critic(fake, fake_graph, stage) minamo_loss = -torch.mean(fake_scores) - ce_loss = F.cross_entropy(fake, real) + ce_loss = F.l1_loss(fake, real) immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) losses = [ minamo_loss * self.weight[0], - ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小 + ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小 immutable_loss * self.weight[2], constraint_loss * self.weight[3] ] @@ -423,4 +423,25 @@ class WGANGinkaLoss: # print(losses[2].item()) - return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss + return sum(losses), minamo_loss, ce_loss, immutable_loss + + def generator_loss_total(self, critic, stage, fake) -> torch.Tensor: + fake_graph = batch_convert_soft_map_to_graph(fake) + + 
fake_scores, _, _ = critic(fake, fake_graph, stage) + minamo_loss = -torch.mean(fake_scores) + immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage]) + constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) + + losses = [ + minamo_loss * self.weight[0], + immutable_loss * self.weight[2], + constraint_loss * self.weight[3] + ] + + if stage == 1: + # 第一个阶段检查入口存在性 + entrance_loss = entrance_constraint_loss(fake) + losses.append(entrance_loss * self.weight[4]) + + return sum(losses) diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 2410853..983566b 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -46,14 +46,19 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch. -def gen_total(gen, input, detach=False) -> torch.Tensor: - fake1 = gen(input, 1) - fake2 = gen(fake1, 2) - fake3 = gen(fake2, 3) - if detach: - return fake3.detach() +def gen_total(gen, input, progress_detach=True, result_detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if progress_detach: + fake1 = gen(input.detach(), 1) + fake2 = gen(fake1.detach(), 2) + fake3 = gen(fake2.detach(), 3) else: - return fake3 + fake1 = gen(input, 1) + fake2 = gen(fake1, 2) + fake3 = gen(fake2, 3) + if result_detach: + return fake1.detach(), fake2.detach(), fake3.detach() + else: + return fake1, fake2, fake3 def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") @@ -67,6 +72,7 @@ def train(): train_stage = 1 mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 random_ratio = 0 + stage3_epoch = 0 # 第三阶段 epoch 数,100 轮后进入第四阶段 ginka = GinkaModel() minamo = MinamoScoreModule() @@ -109,6 +115,9 @@ def train(): if data_ginka.get("random_ratio") is not None: random_ratio = data_ginka["random_ratio"] + if data_ginka.get("stage_epoch3") is not None: + stage3_epoch = data_ginka["stage_epoch3"] + if data_ginka.get("stage") is not None: train_stage = data_ginka["stage"] @@ -151,18 
+160,19 @@ def train(): if train_stage == 1 or train_stage == 2: fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) - loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1) - loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2) - loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3) - - dis_avg = (dis1 + dis2 + dis3) / 3.0 - loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0 + elif train_stage == 3 or train_stage == 4: + fake1, fake2, fake3 = gen_total(ginka, masked1, True, True) + + loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1) + loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2) + loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3) + + dis_avg = (dis1 + dis2 + dis3) / 3.0 + loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0 - # 反向传播 - loss_d_avg.backward() - elif train_stage == 3: - pass - + # 反向传播 + loss_d_avg.backward() + optimizer_minamo.step() loss_total_minamo += loss_d_avg.detach() @@ -188,8 +198,17 @@ def train(): loss_total_ginka += loss_g.detach() loss_ce_total += loss_ce.detach() - elif train_stage == 3: - pass + elif train_stage == 3 or train_stage == 4: + fake1, fake2, fake3 = gen_total(ginka, masked1, True, False) + + loss_g1 = criterion.generator_loss_total(minamo, 1, fake1) + loss_g2 = criterion.generator_loss_total(minamo, 2, fake2) + loss_g3 = criterion.generator_loss_total(minamo, 3, fake3) + + loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 + loss_g.backward() + optimizer_ginka.step() + loss_total_ginka += loss_g.detach() avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps @@ -202,12 +221,14 @@ def train(): f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}" ) - if avg_loss_ce < 0.5: + if avg_loss_ce < 0.1: low_loss_epochs += 1 else: low_loss_epochs = 0 if low_loss_epochs >= 5 and train_stage == 2: + if random_ratio >= 0.5: + 
train_stage = 3 random_ratio += 0.1 random_ratio = min(random_ratio, 0.5) low_loss_epochs = 0 @@ -215,11 +236,20 @@ def train(): if low_loss_epochs >= 5 and train_stage == 1: if mask_ratio >= 0.9: train_stage = 2 - mask_ratio += 0.1 mask_ratio = min(mask_ratio, 0.9) low_loss_epochs = 0 + if train_stage == 3: + stage3_epoch += 1 + if stage3_epoch >= 100: + train_stage = 4 + stage3_epoch = 0 + + if train_stage >= 2: + # 第二阶段后 L1 损失不再应该生效 + mask_ratio = 1.0 + dataset.train_stage = 2 dataset_val.train_stage = 2 dataset.random_ratio = random_ratio @@ -235,8 +265,8 @@ def train(): else: g_steps = 1 - if avg_loss_ginka > 0 or avg_loss_minamo > 0: - c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1)) + if avg_loss_minamo > 0: + c_steps = int(min(5 + avg_loss_minamo * 5, 15)) else: c_steps = 5 @@ -251,6 +281,7 @@ def train(): "stage": train_stage, "mask_ratio": mask_ratio, "random_ratio": random_ratio, + "stage_epoch3": stage3_epoch, }, f"result/wgan/ginka-{epoch + 1}.pth") torch.save({ "model_state": minamo.state_dict(),