mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
perf: 修改损失值计算方式
This commit is contained in:
parent
f6b1ad6ebd
commit
447c28ff5e
@ -31,8 +31,22 @@ def load_minamo_gan_data(data: list):
|
||||
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
|
||||
return res
|
||||
|
||||
def apply_curriculum_remove(
|
||||
maps: torch.Tensor,
|
||||
remove_classes: List[int], # 要移除的类别索引
|
||||
):
|
||||
C, H, W = maps.shape
|
||||
device = maps.device
|
||||
removed_maps = maps.clone()
|
||||
|
||||
remove_mask = removed_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
|
||||
removed_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
|
||||
removed_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地”
|
||||
|
||||
return removed_maps.to(device)
|
||||
|
||||
def apply_curriculum_mask(
|
||||
maps: torch.Tensor, # [B, C, H, W]
|
||||
maps: torch.Tensor, # [C, H, W]
|
||||
mask_classes: List[int], # 要遮挡的类别索引
|
||||
remove_classes: List[int], # 要移除的类别索引
|
||||
mask_ratio: float # 遮挡比例 0~1
|
||||
@ -73,6 +87,42 @@ class GinkaWGANDataset(Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def handle_stage1(self, target):
|
||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
|
||||
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
|
||||
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
|
||||
|
||||
return removed1, masked1, removed2, masked2, removed3, masked3
|
||||
|
||||
def handle_stage2(self, target):
|
||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
||||
# 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可
|
||||
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
|
||||
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
|
||||
|
||||
if self.random_ratio > 0:
|
||||
rd = random.uniform(0, self.random_ratio)
|
||||
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||
masked2 = random_smooth_onehot(masked2, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||
masked3 = random_smooth_onehot(masked3, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||
|
||||
return removed1, masked1, removed2, masked2, removed3, masked3
|
||||
|
||||
def handle_stage3(self, target):
|
||||
rd = random.uniform(0, self.random_ratio)
|
||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
||||
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
||||
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
||||
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
|
||||
return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
|
||||
|
||||
def handle_stage4(self, target):
|
||||
input1 = torch.rand((32, 13, 13))
|
||||
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
|
||||
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
||||
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
||||
return removed1, input1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
@ -80,18 +130,16 @@ class GinkaWGANDataset(Dataset):
|
||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
if self.train_stage == 1:
|
||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
|
||||
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
|
||||
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
|
||||
return self.handle_stage1(target)
|
||||
|
||||
elif self.train_stage == 2:
|
||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
||||
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9))
|
||||
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9))
|
||||
return self.handle_stage2(target)
|
||||
|
||||
if self.random_ratio > 0:
|
||||
removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
||||
removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
||||
removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
||||
elif self.train_stage == 3:
|
||||
return self.handle_stage3(target)
|
||||
|
||||
return removed1, masked1, removed2, masked2, removed3, masked3
|
||||
elif self.train_stage == 4:
|
||||
return self.handle_stage4(target)
|
||||
|
||||
raise RuntimeError(f"Invalid train stage: {self.train_stage}")
|
||||
|
||||
@ -327,7 +327,7 @@ def immutable_penalty_loss(
|
||||
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
||||
|
||||
# 差异区域(模型试图改变的地方)
|
||||
penalty = F.cross_entropy(input_mask, target_mask)
|
||||
penalty = F.l1_loss(input_mask, target_mask)
|
||||
|
||||
return penalty
|
||||
|
||||
@ -405,13 +405,13 @@ class WGANGinkaLoss:
|
||||
|
||||
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.cross_entropy(fake, real)
|
||||
ce_loss = F.l1_loss(fake, real)
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3]
|
||||
]
|
||||
@ -423,4 +423,25 @@ class WGANGinkaLoss:
|
||||
|
||||
# print(losses[2].item())
|
||||
|
||||
return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss
|
||||
return sum(losses), minamo_loss, ce_loss, immutable_loss
|
||||
|
||||
def generator_loss_total(self, critic, stage, fake) -> torch.Tensor:
|
||||
fake_graph = batch_convert_soft_map_to_graph(fake)
|
||||
|
||||
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(fake, fake, STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3]
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
# 第一个阶段检查入口存在性
|
||||
entrance_loss = entrance_constraint_loss(fake)
|
||||
losses.append(entrance_loss * self.weight[4])
|
||||
|
||||
return sum(losses)
|
||||
|
||||
@ -46,14 +46,19 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.
|
||||
else:
|
||||
return fake1, fake2, fake3
|
||||
|
||||
def gen_total(gen, input, detach=False) -> torch.Tensor:
|
||||
fake1 = gen(input, 1)
|
||||
fake2 = gen(fake1, 2)
|
||||
fake3 = gen(fake2, 3)
|
||||
if detach:
|
||||
return fake3.detach()
|
||||
def gen_total(gen, input, progress_detach=True, result_detach=False) -> torch.Tensor:
|
||||
if progress_detach:
|
||||
fake1 = gen(input.detach(), 1)
|
||||
fake2 = gen(fake1.detach(), 2)
|
||||
fake3 = gen(fake2.detach(), 3)
|
||||
else:
|
||||
return fake3
|
||||
fake1 = gen(input, 1)
|
||||
fake2 = gen(fake1, 2)
|
||||
fake3 = gen(fake2, 3)
|
||||
if result_detach:
|
||||
return fake1.detach(), fake2.detach(), fake3.detach()
|
||||
else:
|
||||
return fake1, fake2, fake3
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
@ -67,6 +72,7 @@ def train():
|
||||
train_stage = 1
|
||||
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||
random_ratio = 0
|
||||
stage3_epoch = 0 # 第三阶段 epoch 数,100 轮后进入第四阶段
|
||||
|
||||
ginka = GinkaModel()
|
||||
minamo = MinamoScoreModule()
|
||||
@ -109,6 +115,9 @@ def train():
|
||||
if data_ginka.get("random_ratio") is not None:
|
||||
random_ratio = data_ginka["random_ratio"]
|
||||
|
||||
if data_ginka.get("stage_epoch3") is not None:
|
||||
stage3_epoch = data_ginka["stage_epoch3"]
|
||||
|
||||
if data_ginka.get("stage") is not None:
|
||||
train_stage = data_ginka["stage"]
|
||||
|
||||
@ -151,18 +160,19 @@ def train():
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||
|
||||
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
||||
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
||||
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
||||
|
||||
dis_avg = (dis1 + dis2 + dis3) / 3.0
|
||||
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
|
||||
|
||||
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
||||
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
||||
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
||||
|
||||
dis_avg = (dis1 + dis2 + dis3) / 3.0
|
||||
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
|
||||
|
||||
# 反向传播
|
||||
loss_d_avg.backward()
|
||||
elif train_stage == 3:
|
||||
pass
|
||||
|
||||
# 反向传播
|
||||
loss_d_avg.backward()
|
||||
|
||||
optimizer_minamo.step()
|
||||
|
||||
loss_total_minamo += loss_d_avg.detach()
|
||||
@ -188,8 +198,17 @@ def train():
|
||||
loss_total_ginka += loss_g.detach()
|
||||
loss_ce_total += loss_ce.detach()
|
||||
|
||||
elif train_stage == 3:
|
||||
pass
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
|
||||
|
||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||
loss_g2 = criterion.generator_loss_total(minamo, 2, fake2)
|
||||
loss_g3 = criterion.generator_loss_total(minamo, 3, fake3)
|
||||
|
||||
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||
loss_g.backward()
|
||||
optimizer_ginka.step()
|
||||
loss_total_ginka += loss_g.detach()
|
||||
|
||||
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
||||
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
||||
@ -202,12 +221,14 @@ def train():
|
||||
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
|
||||
)
|
||||
|
||||
if avg_loss_ce < 0.5:
|
||||
if avg_loss_ce < 0.1:
|
||||
low_loss_epochs += 1
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 2:
|
||||
if random_ratio >= 0.5:
|
||||
train_stage = 3
|
||||
random_ratio += 0.1
|
||||
random_ratio = min(random_ratio, 0.5)
|
||||
low_loss_epochs = 0
|
||||
@ -215,11 +236,20 @@ def train():
|
||||
if low_loss_epochs >= 5 and train_stage == 1:
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
|
||||
mask_ratio += 0.1
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
|
||||
if train_stage == 3:
|
||||
stage3_epoch += 1
|
||||
if stage3_epoch >= 100:
|
||||
train_stage = 4
|
||||
stage3_epoch = 0
|
||||
|
||||
if train_stage >= 2:
|
||||
# 第二阶段后 L1 损失不再应该生效
|
||||
mask_ratio = 1.0
|
||||
|
||||
dataset.train_stage = 2
|
||||
dataset_val.train_stage = 2
|
||||
dataset.random_ratio = random_ratio
|
||||
@ -235,8 +265,8 @@ def train():
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
|
||||
c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1))
|
||||
if avg_loss_minamo > 0:
|
||||
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||
else:
|
||||
c_steps = 5
|
||||
|
||||
@ -251,6 +281,7 @@ def train():
|
||||
"stage": train_stage,
|
||||
"mask_ratio": mask_ratio,
|
||||
"random_ratio": random_ratio,
|
||||
"stage3_epoch": stage3_epoch,
|
||||
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict(),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user