mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 损失值改进
This commit is contained in:
parent
53041ab754
commit
e3e496957c
@ -159,8 +159,11 @@ class GinkaWGANDataset(Dataset):
|
||||
item = self.data[idx]
|
||||
|
||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
C, H, W = target.shape
|
||||
tag_cond = torch.FloatTensor(item['tag'])
|
||||
val_cond = torch.FloatTensor(item['val'])
|
||||
val_cond[9] = val_cond[9] / H / W
|
||||
val_cond[10] = val_cond[10] / H / W
|
||||
|
||||
if self.train_stage == 1:
|
||||
return self.handle_stage1(target, tag_cond, val_cond)
|
||||
|
||||
@ -18,6 +18,20 @@ STAGE_ALLOWED = [
|
||||
list(range(7, 26))
|
||||
]
|
||||
|
||||
DENSITY_MAP = [
|
||||
[1, *list(range(3, 30))],
|
||||
[1],
|
||||
[2],
|
||||
[3, 4, 5, 6],
|
||||
[26, 27, 28],
|
||||
list(range(7, 26)),
|
||||
list(range(10, 19)),
|
||||
[19, 20, 21, 22],
|
||||
[7, 8, 9],
|
||||
[23, 24, 25],
|
||||
[29, 30]
|
||||
]
|
||||
|
||||
def get_not_allowed(classes: list[int], include_illegal=False):
|
||||
res = list()
|
||||
for num in range(0, CLASS_NUM):
|
||||
@ -247,6 +261,20 @@ def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
|
||||
|
||||
return wall_penalty
|
||||
|
||||
def compute_multi_density_loss(probs, target_densities):
|
||||
"""
|
||||
pred: [B, C, H, W]
|
||||
target_densities: [B, N] - N 个目标类别密度
|
||||
class_indices: [N] - 对应类别通道索引
|
||||
"""
|
||||
losses = []
|
||||
for i, classes in enumerate(DENSITY_MAP):
|
||||
class_map = probs[:, classes, :, :]
|
||||
pred_density = torch.mean(class_map, dim=(1, 2, 3))
|
||||
loss = F.mse_loss(pred_density, target_densities[:, i])
|
||||
losses.append(loss)
|
||||
return sum(losses)
|
||||
|
||||
class GinkaLoss(nn.Module):
|
||||
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
|
||||
"""Ginka Model 损失函数部分
|
||||
@ -347,8 +375,8 @@ def immutable_penalty_loss(
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.2]):
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
|
||||
@ -417,6 +445,7 @@ class WGANGinkaLoss:
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -426,6 +455,7 @@ class WGANGinkaLoss:
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
density_loss * self.weight[6],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -444,6 +474,7 @@ class WGANGinkaLoss:
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -451,6 +482,7 @@ class WGANGinkaLoss:
|
||||
minamo_loss * self.weight[0],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
density_loss * self.weight[6],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -468,6 +500,7 @@ class WGANGinkaLoss:
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
@ -476,6 +509,7 @@ class WGANGinkaLoss:
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
density_loss * self.weight[6],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
|
||||
@ -152,16 +152,6 @@ def train():
|
||||
if data_minamo.get("optim_state") is not None:
|
||||
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
|
||||
|
||||
dataset.train_stage = train_stage
|
||||
dataset.mask_ratio1 = mask_ratio
|
||||
dataset.mask_ratio2 = mask_ratio
|
||||
dataset.mask_ratio3 = mask_ratio
|
||||
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset_val.mask_ratio1 = mask_ratio
|
||||
dataset_val.mask_ratio2 = mask_ratio
|
||||
dataset_val.mask_ratio3 = mask_ratio
|
||||
|
||||
print("Train from loaded state.")
|
||||
|
||||
curr_epoch = args.curr_epoch
|
||||
@ -172,6 +162,16 @@ def train():
|
||||
stage_epoch = 0
|
||||
mask_ratio = 0.2
|
||||
|
||||
dataset.train_stage = train_stage
|
||||
dataset.mask_ratio1 = mask_ratio
|
||||
dataset.mask_ratio2 = mask_ratio
|
||||
dataset.mask_ratio3 = mask_ratio
|
||||
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset_val.mask_ratio1 = mask_ratio
|
||||
dataset_val.mask_ratio2 = mask_ratio
|
||||
dataset_val.mask_ratio3 = mask_ratio
|
||||
|
||||
low_loss_epochs = 0
|
||||
|
||||
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
||||
@ -317,9 +317,9 @@ def train():
|
||||
|
||||
# 训练流程控制
|
||||
|
||||
if mask_ratio < 0.5 and avg_loss_ce < 0.2:
|
||||
if mask_ratio < 0.5 and avg_loss_ce < 0.5:
|
||||
low_loss_epochs += 1
|
||||
elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
|
||||
elif mask_ratio > 0.5 and avg_loss_ce < 0.5:
|
||||
low_loss_epochs += 1
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user