feat: 损失值改进

This commit is contained in:
unanmed 2025-05-02 13:59:56 +08:00
parent 53041ab754
commit e3e496957c
3 changed files with 51 additions and 14 deletions

View File

@ -159,8 +159,11 @@ class GinkaWGANDataset(Dataset):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
C, H, W = target.shape
tag_cond = torch.FloatTensor(item['tag'])
val_cond = torch.FloatTensor(item['val'])
val_cond[9] = val_cond[9] / H / W
val_cond[10] = val_cond[10] / H / W
if self.train_stage == 1:
return self.handle_stage1(target, tag_cond, val_cond)

View File

@ -18,6 +18,20 @@ STAGE_ALLOWED = [
list(range(7, 26))
]
# Groups of tile-class channel indices used by compute_multi_density_loss:
# entry i lists the class channels whose combined on-map density is compared
# against target_densities[:, i], so the target vector must have
# len(DENSITY_MAP) == 11 entries.
# NOTE(review): the groups overlap (class 1 appears in groups 0 and 1, and
# group 0 spans almost all classes) — presumably intentional so that both an
# aggregate density and per-category densities are supervised; confirm
# against how the val_cond vector is constructed in the dataset.
DENSITY_MAP = [
    [1, *list(range(3, 30))],
    [1],
    [2],
    [3, 4, 5, 6],
    [26, 27, 28],
    list(range(7, 26)),
    list(range(10, 19)),
    [19, 20, 21, 22],
    [7, 8, 9],
    [23, 24, 25],
    [29, 30]
]
def get_not_allowed(classes: list[int], include_illegal=False):
res = list()
for num in range(0, CLASS_NUM):
@ -247,6 +261,20 @@ def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
return wall_penalty
def compute_multi_density_loss(probs, target_densities):
    """Sum of per-group MSE losses between predicted and target class densities.

    One loss term is computed for each class group in the module-level
    DENSITY_MAP, then the terms are summed.

    Args:
        probs: [B, C, H, W] per-class map probabilities
            (presumably post-softmax — confirm at the call site).
        target_densities: [B, N] target density for each DENSITY_MAP group,
            where N == len(DENSITY_MAP).

    Returns:
        Scalar tensor: the sum of the N per-group MSE losses.
    """
    losses = []
    for i, classes in enumerate(DENSITY_MAP):
        # Channels belonging to group i: [B, len(classes), H, W].
        class_map = probs[:, classes, :, :]
        # Mean over channel AND spatial dims -> [B]. NOTE(review): this also
        # divides by len(classes); if the target is (tile count) / (H * W)
        # (as the val_cond normalization in the dataset suggests), a sum over
        # channels followed by a spatial mean may be intended — TODO confirm.
        pred_density = torch.mean(class_map, dim=(1, 2, 3))
        loss = F.mse_loss(pred_density, target_densities[:, i])
        losses.append(loss)
    # Python sum over 0-dim tensors yields a single scalar loss tensor.
    return sum(losses)
class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
"""Ginka Model 损失函数部分
@ -347,8 +375,8 @@ def immutable_penalty_loss(
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
# weight: 判别器损失CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.2]):
# weight: 判别器损失CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
@ -417,6 +445,7 @@ class WGANGinkaLoss:
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -426,6 +455,7 @@ class WGANGinkaLoss:
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
density_loss * self.weight[6],
]
if stage == 1:
@ -444,6 +474,7 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -451,6 +482,7 @@ class WGANGinkaLoss:
minamo_loss * self.weight[0],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
density_loss * self.weight[6],
]
if stage == 1:
@ -468,6 +500,7 @@ class WGANGinkaLoss:
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -476,6 +509,7 @@ class WGANGinkaLoss:
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
density_loss * self.weight[6],
]
if stage == 1:

View File

@ -151,16 +151,6 @@ def train():
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
if data_minamo.get("optim_state") is not None:
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
dataset.train_stage = train_stage
dataset.mask_ratio1 = mask_ratio
dataset.mask_ratio2 = mask_ratio
dataset.mask_ratio3 = mask_ratio
dataset_val.train_stage = train_stage
dataset_val.mask_ratio1 = mask_ratio
dataset_val.mask_ratio2 = mask_ratio
dataset_val.mask_ratio3 = mask_ratio
print("Train from loaded state.")
@ -172,6 +162,16 @@ def train():
stage_epoch = 0
mask_ratio = 0.2
dataset.train_stage = train_stage
dataset.mask_ratio1 = mask_ratio
dataset.mask_ratio2 = mask_ratio
dataset.mask_ratio3 = mask_ratio
dataset_val.train_stage = train_stage
dataset_val.mask_ratio1 = mask_ratio
dataset_val.mask_ratio2 = mask_ratio
dataset_val.mask_ratio3 = mask_ratio
low_loss_epochs = 0
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
@ -317,9 +317,9 @@ def train():
# 训练流程控制
if mask_ratio < 0.5 and avg_loss_ce < 0.2:
if mask_ratio < 0.5 and avg_loss_ce < 0.5:
low_loss_epochs += 1
elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
elif mask_ratio > 0.5 and avg_loss_ce < 0.5:
low_loss_epochs += 1
else:
low_loss_epochs = 0