perf: 修改损失值计算方式

This commit is contained in:
unanmed 2025-04-14 14:15:13 +08:00
parent f6b1ad6ebd
commit 447c28ff5e
3 changed files with 140 additions and 40 deletions

View File

@ -31,8 +31,22 @@ def load_minamo_gan_data(data: list):
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
return res
def apply_curriculum_remove(
    maps: torch.Tensor,
    remove_classes: List[int],  # channel indices of the classes to erase
) -> torch.Tensor:
    """Erase the given one-hot classes from a map tensor.

    Every cell occupied by any class in ``remove_classes`` is cleared across
    all channels and re-labelled as class 0 ("empty floor").

    Args:
        maps: one-hot map tensor of shape [C, H, W].
        remove_classes: channel indices to remove.

    Returns:
        A new tensor of the same shape, on the same device, with the
        classes removed (``maps`` itself is not modified).
    """
    C, _, _ = maps.shape
    removed_maps = maps.clone()
    # Cells where any removed class is active, shape [1, H, W].
    remove_mask = removed_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
    # Clear every channel at those cells, then mark them as class 0 ("empty").
    removed_maps[remove_mask.expand(C, -1, -1)] = 0
    removed_maps[0][remove_mask[0, :, :]] = 1
    # clone() stays on maps.device, so no extra .to(device) is needed.
    return removed_maps
def apply_curriculum_mask(
maps: torch.Tensor, # [B, C, H, W]
maps: torch.Tensor, # [C, H, W]
mask_classes: List[int], # 要遮挡的类别索引
remove_classes: List[int], # 要移除的类别索引
mask_ratio: float # 遮挡比例 0~1
@ -73,6 +87,42 @@ class GinkaWGANDataset(Dataset):
def __len__(self):
return len(self.data)
def handle_stage1(self, target):
    """Build a stage-1 sample: (removed, masked) pairs for all three
    curriculum stages, using the dataset's fixed per-stage mask ratios."""
    pairs = []
    for mask_cls, remove_cls, ratio in (
        (STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1),
        (STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2),
        (STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3),
    ):
        # apply_curriculum_mask yields a (removed, masked) pair per stage.
        pairs.extend(apply_curriculum_mask(target, mask_cls, remove_cls, ratio))
    return tuple(pairs)
def handle_stage2(self, target):
    """Build a stage-2 sample: mask each curriculum stage with a random
    ratio and optionally label-smooth the masked one-hot maps."""
    removed1, masked1 = apply_curriculum_mask(
        target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)
    )
    # The later two stages keep some classes anyway, so a fully random
    # masking ratio is fine for them.
    removed2, masked2 = apply_curriculum_mask(
        target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1)
    )
    removed3, masked3 = apply_curriculum_mask(
        target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1)
    )
    if self.random_ratio > 0:
        rd = random.uniform(0, self.random_ratio)
        masked1, masked2, masked3 = (
            random_smooth_onehot(m, min_main=1 - rd, max_main=1.0, epsilon=rd)
            for m in (masked1, masked2, masked3)
        )
    return removed1, masked1, removed2, masked2, removed3, masked3
def handle_stage3(self, target):
    """Build a stage-3 sample: only the first stage keeps a masked input;
    the later stages have their removable classes stripped outright."""
    rd = random.uniform(0, self.random_ratio)
    removed1, masked1 = apply_curriculum_mask(
        target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)
    )
    # Later stages provide no masked variant — just the removal result.
    removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
    removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
    masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
    return (
        removed1,
        masked1,
        removed2,
        torch.zeros_like(target),
        removed3,
        torch.zeros_like(target),
    )
def handle_stage4(self, target):
    """Build a stage-4 sample: the first stage gets pure noise as input and
    every stage's removable classes are stripped from the target."""
    # Noise shaped like a one-hot map; presumably 32 classes on a 13x13
    # grid — NOTE(review): confirm against the dataset's map dimensions.
    noise = torch.rand((32, 13, 13))
    stripped = [
        apply_curriculum_remove(target, classes)
        for classes in (STAGE1_REMOVE, STAGE2_REMOVE, STAGE3_REMOVE)
    ]
    return (
        stripped[0],
        noise,
        stripped[1],
        torch.zeros_like(target),
        stripped[2],
        torch.zeros_like(target),
    )
def __getitem__(self, idx):
item = self.data[idx]
@ -80,18 +130,16 @@ class GinkaWGANDataset(Dataset):
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
if self.train_stage == 1:
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
return self.handle_stage1(target)
elif self.train_stage == 2:
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9))
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9))
return self.handle_stage2(target)
if self.random_ratio > 0:
removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
elif self.train_stage == 3:
return self.handle_stage3(target)
return removed1, masked1, removed2, masked2, removed3, masked3
elif self.train_stage == 4:
return self.handle_stage4(target)
raise RuntimeError(f"Invalid train stage: {self.train_stage}")

View File

@ -327,7 +327,7 @@ def immutable_penalty_loss(
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
# 差异区域(模型试图改变的地方)
penalty = F.cross_entropy(input_mask, target_mask)
penalty = F.l1_loss(input_mask, target_mask)
return penalty
@ -405,13 +405,13 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real)
ce_loss = F.l1_loss(fake, real)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
losses = [
minamo_loss * self.weight[0],
ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
immutable_loss * self.weight[2],
constraint_loss * self.weight[3]
]
@ -423,4 +423,25 @@ class WGANGinkaLoss:
# print(losses[2].item())
return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss
return sum(losses), minamo_loss, ce_loss, immutable_loss
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
    """Generator loss for full-pipeline training, where no ground-truth
    target is available (adversarial + penalty terms only)."""
    fake_graph = batch_convert_soft_map_to_graph(fake)
    scores, _, _ = critic(fake, fake_graph, stage)
    # WGAN generator objective: maximize the critic score.
    adv_loss = -torch.mean(scores)
    # No reference input here, so the map is compared against itself.
    immutable = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
    constraint = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
    total = (
        adv_loss * self.weight[0]
        + immutable * self.weight[2]
        + constraint * self.weight[3]
    )
    if stage == 1:
        # Stage 1 additionally checks that an entrance exists.
        total = total + entrance_constraint_loss(fake) * self.weight[4]
    return total

View File

@ -46,14 +46,19 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.
else:
return fake1, fake2, fake3
def gen_total(gen, input, progress_detach=True, result_detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the generator through all three curriculum stages in sequence.

    Args:
        gen: generator callable invoked as ``gen(x, stage)``.
        input: input tensor fed to stage 1.
        progress_detach: if True, detach the tensor passed into each stage so
            a later stage's gradients do not flow back through earlier stages.
        result_detach: if True, detach all three returned outputs.

    Returns:
        The stage-1, stage-2 and stage-3 generator outputs as a tuple.
    """
    if progress_detach:
        fake1 = gen(input.detach(), 1)
        fake2 = gen(fake1.detach(), 2)
        fake3 = gen(fake2.detach(), 3)
    else:
        fake1 = gen(input, 1)
        fake2 = gen(fake1, 2)
        fake3 = gen(fake2, 3)
    if result_detach:
        return fake1.detach(), fake2.detach(), fake3.detach()
    else:
        return fake1, fake2, fake3
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@ -67,6 +72,7 @@ def train():
train_stage = 1
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
random_ratio = 0
stage3_epoch = 0 # 第三阶段 epoch 数100 轮后进入第四阶段
ginka = GinkaModel()
minamo = MinamoScoreModule()
@ -109,6 +115,9 @@ def train():
if data_ginka.get("random_ratio") is not None:
random_ratio = data_ginka["random_ratio"]
if data_ginka.get("stage_epoch3") is not None:
stage3_epoch = data_ginka["stage_epoch3"]
if data_ginka.get("stage") is not None:
train_stage = data_ginka["stage"]
@ -151,18 +160,19 @@ def train():
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
dis_avg = (dis1 + dis2 + dis3) / 3.0
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
dis_avg = (dis1 + dis2 + dis3) / 3.0
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
# 反向传播
loss_d_avg.backward()
elif train_stage == 3:
pass
# 反向传播
loss_d_avg.backward()
optimizer_minamo.step()
loss_total_minamo += loss_d_avg.detach()
@ -188,8 +198,17 @@ def train():
loss_total_ginka += loss_g.detach()
loss_ce_total += loss_ce.detach()
elif train_stage == 3:
pass
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
@ -202,12 +221,14 @@ def train():
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
)
if avg_loss_ce < 0.5:
if avg_loss_ce < 0.1:
low_loss_epochs += 1
else:
low_loss_epochs = 0
if low_loss_epochs >= 5 and train_stage == 2:
if random_ratio >= 0.5:
train_stage = 3
random_ratio += 0.1
random_ratio = min(random_ratio, 0.5)
low_loss_epochs = 0
@ -215,11 +236,20 @@ def train():
if low_loss_epochs >= 5 and train_stage == 1:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.1
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
if train_stage == 3:
stage3_epoch += 1
if stage3_epoch >= 100:
train_stage = 4
stage3_epoch = 0
if train_stage >= 2:
# 第二阶段后 L1 损失不再应该生效
mask_ratio = 1.0
dataset.train_stage = 2
dataset_val.train_stage = 2
dataset.random_ratio = random_ratio
@ -235,8 +265,8 @@ def train():
else:
g_steps = 1
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1))
if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else:
c_steps = 5
@ -251,6 +281,7 @@ def train():
"stage": train_stage,
"mask_ratio": mask_ratio,
"random_ratio": random_ratio,
"stage3_epoch": stage3_epoch,
}, f"result/wgan/ginka-{epoch + 1}.pth")
torch.save({
"model_state": minamo.state_dict(),