perf: 修改损失值计算方式

This commit is contained in:
unanmed 2025-04-14 14:15:13 +08:00
parent f6b1ad6ebd
commit 447c28ff5e
3 changed files with 140 additions and 40 deletions

View File

@ -31,8 +31,22 @@ def load_minamo_gan_data(data: list):
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True)) res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
return res return res
def apply_curriculum_remove(
maps: torch.Tensor,
remove_classes: List[int], # 要移除的类别索引
):
C, H, W = maps.shape
device = maps.device
removed_maps = maps.clone()
remove_mask = removed_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
removed_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
removed_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地”
return removed_maps.to(device)
def apply_curriculum_mask( def apply_curriculum_mask(
maps: torch.Tensor, # [B, C, H, W] maps: torch.Tensor, # [C, H, W]
mask_classes: List[int], # 要遮挡的类别索引 mask_classes: List[int], # 要遮挡的类别索引
remove_classes: List[int], # 要移除的类别索引 remove_classes: List[int], # 要移除的类别索引
mask_ratio: float # 遮挡比例 0~1 mask_ratio: float # 遮挡比例 0~1
@ -73,6 +87,42 @@ class GinkaWGANDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
def handle_stage1(self, target):
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
return removed1, masked1, removed2, masked2, removed3, masked3
def handle_stage2(self, target):
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
# 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
if self.random_ratio > 0:
rd = random.uniform(0, self.random_ratio)
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
masked2 = random_smooth_onehot(masked2, min_main=1 - rd, max_main=1.0, epsilon=rd)
masked3 = random_smooth_onehot(masked3, min_main=1 - rd, max_main=1.0, epsilon=rd)
return removed1, masked1, removed2, masked2, removed3, masked3
def handle_stage3(self, target):
rd = random.uniform(0, self.random_ratio)
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
def handle_stage4(self, target):
input1 = torch.rand((32, 13, 13))
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
return removed1, input1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
def __getitem__(self, idx): def __getitem__(self, idx):
item = self.data[idx] item = self.data[idx]
@ -80,18 +130,16 @@ class GinkaWGANDataset(Dataset):
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
if self.train_stage == 1: if self.train_stage == 1:
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) return self.handle_stage1(target)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
elif self.train_stage == 2: elif self.train_stage == 2:
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) return self.handle_stage2(target)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9))
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9))
if self.random_ratio > 0: elif self.train_stage == 3:
removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio) return self.handle_stage3(target)
removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
return removed1, masked1, removed2, masked2, removed3, masked3 elif self.train_stage == 4:
return self.handle_stage4(target)
raise RuntimeError(f"Invalid train stage: {self.train_stage}")

View File

@ -327,7 +327,7 @@ def immutable_penalty_loss(
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float() target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
# 差异区域(模型试图改变的地方) # 差异区域(模型试图改变的地方)
penalty = F.cross_entropy(input_mask, target_mask) penalty = F.l1_loss(input_mask, target_mask)
return penalty return penalty
@ -405,13 +405,13 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(fake, fake_graph, stage) fake_scores, _, _ = critic(fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores) minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real) ce_loss = F.l1_loss(fake, real)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake) constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
losses = [ losses = [
minamo_loss * self.weight[0], minamo_loss * self.weight[0],
ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小 ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
immutable_loss * self.weight[2], immutable_loss * self.weight[2],
constraint_loss * self.weight[3] constraint_loss * self.weight[3]
] ]
@ -423,4 +423,25 @@ class WGANGinkaLoss:
# print(losses[2].item()) # print(losses[2].item())
return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss return sum(losses), minamo_loss, ce_loss, immutable_loss
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
fake_graph = batch_convert_soft_map_to_graph(fake)
fake_scores, _, _ = critic(fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
losses = [
minamo_loss * self.weight[0],
immutable_loss * self.weight[2],
constraint_loss * self.weight[3]
]
if stage == 1:
# 第一个阶段检查入口存在性
entrance_loss = entrance_constraint_loss(fake)
losses.append(entrance_loss * self.weight[4])
return sum(losses)

View File

@ -46,14 +46,19 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.
else: else:
return fake1, fake2, fake3 return fake1, fake2, fake3
def gen_total(gen, input, detach=False) -> torch.Tensor: def gen_total(gen, input, progress_detach=True, result_detach=False) -> torch.Tensor:
fake1 = gen(input, 1) if progress_detach:
fake2 = gen(fake1, 2) fake1 = gen(input.detach(), 1)
fake3 = gen(fake2, 3) fake2 = gen(fake1.detach(), 2)
if detach: fake3 = gen(fake2.detach(), 3)
return fake3.detach()
else: else:
return fake3 fake1 = gen(input, 1)
fake2 = gen(fake1, 2)
fake3 = gen(fake2, 3)
if result_detach:
return fake1.detach(), fake2.detach(), fake3.detach()
else:
return fake1, fake2, fake3
def train(): def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@ -67,6 +72,7 @@ def train():
train_stage = 1 train_stage = 1
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
random_ratio = 0 random_ratio = 0
stage3_epoch = 0 # 第三阶段 epoch 数100 轮后进入第四阶段
ginka = GinkaModel() ginka = GinkaModel()
minamo = MinamoScoreModule() minamo = MinamoScoreModule()
@ -109,6 +115,9 @@ def train():
if data_ginka.get("random_ratio") is not None: if data_ginka.get("random_ratio") is not None:
random_ratio = data_ginka["random_ratio"] random_ratio = data_ginka["random_ratio"]
if data_ginka.get("stage_epoch3") is not None:
stage3_epoch = data_ginka["stage_epoch3"]
if data_ginka.get("stage") is not None: if data_ginka.get("stage") is not None:
train_stage = data_ginka["stage"] train_stage = data_ginka["stage"]
@ -151,18 +160,19 @@ def train():
if train_stage == 1 or train_stage == 2: if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1) elif train_stage == 3 or train_stage == 4:
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2) fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
dis_avg = (dis1 + dis2 + dis3) / 3.0 loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0 loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
dis_avg = (dis1 + dis2 + dis3) / 3.0
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
# 反向传播 # 反向传播
loss_d_avg.backward() loss_d_avg.backward()
elif train_stage == 3:
pass
optimizer_minamo.step() optimizer_minamo.step()
loss_total_minamo += loss_d_avg.detach() loss_total_minamo += loss_d_avg.detach()
@ -188,8 +198,17 @@ def train():
loss_total_ginka += loss_g.detach() loss_total_ginka += loss_g.detach()
loss_ce_total += loss_ce.detach() loss_ce_total += loss_ce.detach()
elif train_stage == 3: elif train_stage == 3 or train_stage == 4:
pass fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
@ -202,12 +221,14 @@ def train():
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}" f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
) )
if avg_loss_ce < 0.5: if avg_loss_ce < 0.1:
low_loss_epochs += 1 low_loss_epochs += 1
else: else:
low_loss_epochs = 0 low_loss_epochs = 0
if low_loss_epochs >= 5 and train_stage == 2: if low_loss_epochs >= 5 and train_stage == 2:
if random_ratio >= 0.5:
train_stage = 3
random_ratio += 0.1 random_ratio += 0.1
random_ratio = min(random_ratio, 0.5) random_ratio = min(random_ratio, 0.5)
low_loss_epochs = 0 low_loss_epochs = 0
@ -215,11 +236,20 @@ def train():
if low_loss_epochs >= 5 and train_stage == 1: if low_loss_epochs >= 5 and train_stage == 1:
if mask_ratio >= 0.9: if mask_ratio >= 0.9:
train_stage = 2 train_stage = 2
mask_ratio += 0.1 mask_ratio += 0.1
mask_ratio = min(mask_ratio, 0.9) mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0 low_loss_epochs = 0
if train_stage == 3:
stage3_epoch += 1
if stage3_epoch >= 100:
train_stage = 4
stage3_epoch = 0
if train_stage >= 2:
# 第二阶段后 L1 损失不再应该生效
mask_ratio = 1.0
dataset.train_stage = 2 dataset.train_stage = 2
dataset_val.train_stage = 2 dataset_val.train_stage = 2
dataset.random_ratio = random_ratio dataset.random_ratio = random_ratio
@ -235,8 +265,8 @@ def train():
else: else:
g_steps = 1 g_steps = 1
if avg_loss_ginka > 0 or avg_loss_minamo > 0: if avg_loss_minamo > 0:
c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1)) c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else: else:
c_steps = 5 c_steps = 5
@ -251,6 +281,7 @@ def train():
"stage": train_stage, "stage": train_stage,
"mask_ratio": mask_ratio, "mask_ratio": mask_ratio,
"random_ratio": random_ratio, "random_ratio": random_ratio,
"stage3_epoch": stage3_epoch,
}, f"result/wgan/ginka-{epoch + 1}.pth") }, f"result/wgan/ginka-{epoch + 1}.pth")
torch.save({ torch.save({
"model_state": minamo.state_dict(), "model_state": minamo.state_dict(),