diff --git a/ginka/dataset.py b/ginka/dataset.py
index acc7823..417a0e9 100644
--- a/ginka/dataset.py
+++ b/ginka/dataset.py
@@ -159,8 +159,11 @@ class GinkaWGANDataset(Dataset):
         item = self.data[idx]
         target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
+        C, H, W = target.shape
         tag_cond = torch.FloatTensor(item['tag'])
         val_cond = torch.FloatTensor(item['val'])
+        val_cond[9] = val_cond[9] / H / W
+        val_cond[10] = val_cond[10] / H / W
 
         if self.train_stage == 1:
             return self.handle_stage1(target, tag_cond, val_cond)
 
diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py
index 0e4edd6..7660d18 100644
--- a/ginka/generator/loss.py
+++ b/ginka/generator/loss.py
@@ -18,6 +18,20 @@ STAGE_ALLOWED = [
     list(range(7, 26))
 ]
 
+DENSITY_MAP = [
+    [1, *list(range(3, 30))],
+    [1],
+    [2],
+    [3, 4, 5, 6],
+    [26, 27, 28],
+    list(range(7, 26)),
+    list(range(10, 19)),
+    [19, 20, 21, 22],
+    [7, 8, 9],
+    [23, 24, 25],
+    [29, 30]
+]
+
 def get_not_allowed(classes: list[int], include_illegal=False):
     res = list()
     for num in range(0, CLASS_NUM):
@@ -247,6 +261,20 @@ def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
     return wall_penalty
 
+def compute_multi_density_loss(probs, target_densities):
+    """
+    probs: [B, C, H, W] class probabilities; target_densities: [B, len(DENSITY_MAP)].
+    Returns the sum over DENSITY_MAP groups of the MSE between predicted and target density.
+    NOTE(review): mean over dim=(1, 2, 3) divides each group's density by len(classes); confirm targets are scaled the same way.
+    """
+    losses = []
+    for i, classes in enumerate(DENSITY_MAP):
+        class_map = probs[:, classes, :, :]
+        pred_density = torch.mean(class_map, dim=(1, 2, 3))
+        loss = F.mse_loss(pred_density, target_densities[:, i])
+        losses.append(loss)
+    return sum(losses)
+
 class GinkaLoss(nn.Module):
     def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
         """Ginka Model 损失函数部分
 
@@ -347,8 +375,8 @@
     return penalty
 
 class WGANGinkaLoss:
-    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
-        # weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
+    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2, 0.2]):
+        # weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失,密度损失
         self.lambda_gp = lambda_gp # 梯度惩罚系数
         self.weight = weight
 
@@ -417,6 +445,7 @@
         ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
         immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
         constraint_loss = inner_constraint_loss(probs_fake)
+        density_loss = compute_multi_density_loss(probs_fake, val_cond)
 
         fake_a, fake_b = fake.chunk(2, dim=0)
 
@@ -426,6 +455,7 @@
             immutable_loss * self.weight[2],
             constraint_loss * self.weight[3],
             -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
+            density_loss * self.weight[6],
         ]
 
         if stage == 1:
@@ -444,6 +474,7 @@
         fake_scores, _, _ = critic(probs_fake, fake_graph, stage, tag_cond, val_cond)
         minamo_loss = -torch.mean(fake_scores)
         constraint_loss = inner_constraint_loss(probs_fake)
+        density_loss = compute_multi_density_loss(probs_fake, val_cond)
 
         fake_a, fake_b = fake.chunk(2, dim=0)
 
@@ -451,6 +482,7 @@
             minamo_loss * self.weight[0],
             constraint_loss * self.weight[3],
             -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
+            density_loss * self.weight[6],
         ]
 
         if stage == 1:
@@ -468,6 +500,7 @@
         minamo_loss = -torch.mean(fake_scores)
         immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
         constraint_loss = inner_constraint_loss(probs_fake)
+        density_loss = compute_multi_density_loss(probs_fake, val_cond)
 
         fake_a, fake_b = fake.chunk(2, dim=0)
 
@@ -476,6 +509,7 @@
             immutable_loss * self.weight[2],
             constraint_loss * self.weight[3],
             -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
+            density_loss * self.weight[6],
         ]
 
         if stage == 1:
diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py
index 48ada16..1f3d419 100644
--- a/ginka/train_wgan.py
+++ b/ginka/train_wgan.py
@@ -151,16 +151,6 @@ def train():
             optimizer_ginka.load_state_dict(data_ginka["optim_state"])
         if data_minamo.get("optim_state") is not None:
             optimizer_minamo.load_state_dict(data_minamo["optim_state"])
-
-        dataset.train_stage = train_stage
-        dataset.mask_ratio1 = mask_ratio
-        dataset.mask_ratio2 = mask_ratio
-        dataset.mask_ratio3 = mask_ratio
-
-        dataset_val.train_stage = train_stage
-        dataset_val.mask_ratio1 = mask_ratio
-        dataset_val.mask_ratio2 = mask_ratio
-        dataset_val.mask_ratio3 = mask_ratio
 
         print("Train from loaded state.")
@@ -172,6 +162,16 @@ def train():
         stage_epoch = 0
         mask_ratio = 0.2
 
+    dataset.train_stage = train_stage
+    dataset.mask_ratio1 = mask_ratio
+    dataset.mask_ratio2 = mask_ratio
+    dataset.mask_ratio3 = mask_ratio
+
+    dataset_val.train_stage = train_stage
+    dataset_val.mask_ratio1 = mask_ratio
+    dataset_val.mask_ratio2 = mask_ratio
+    dataset_val.mask_ratio3 = mask_ratio
+
     low_loss_epochs = 0
 
     for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
@@ -317,9 +317,9 @@ def train():
 
         # 训练流程控制
-        if mask_ratio < 0.5 and avg_loss_ce < 0.2:
+        if mask_ratio < 0.5 and avg_loss_ce < 0.5:
             low_loss_epochs += 1
-        elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
+        elif mask_ratio >= 0.5 and avg_loss_ce < 0.5:
             low_loss_epochs += 1
         else:
             low_loss_epochs = 0
 