chore: 微调模型

This commit is contained in:
unanmed 2025-06-14 15:06:09 +08:00
parent fb0323d874
commit 5586ea1039
4 changed files with 35 additions and 47 deletions

View File

@ -230,8 +230,8 @@ class MinamoModel2(nn.Module):
super().__init__()
self.cond = ConditionEncoder(64, 16, 256, 256)
self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13)
self.conv2 = ConvFusionModule(128, 256, 256, 13, 13)
self.conv1 = ConvFusionModule(tile_types, 256, 256, 13, 13)
self.conv2 = ConvFusionModule(256, 512, 256, 13, 13)
self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)
self.head0 = MinamoHead2(256, 256) # 随机头的判别头
@ -239,9 +239,9 @@ class MinamoModel2(nn.Module):
self.head2 = MinamoHead2(256, 256)
self.head3 = MinamoHead2(256, 256)
# self.inject1 = ConditionInjector(256, 128)
# self.inject1 = ConditionInjector(256, 256)
# self.inject2 = ConditionInjector(256, 256)
# self.inject3 = ConditionInjector(256, 256)
self.inject3 = ConditionInjector(256, 256)
def forward(self, x, stage, tag_cond, val_cond):
B, D = tag_cond.shape
@ -252,7 +252,7 @@ class MinamoModel2(nn.Module):
x = self.conv2(x)
# x = self.inject2(x, cond)
x = self.conv3(x)
# x = self.inject3(x, cond)
x = self.inject3(x, cond)
if stage == 0:
score = self.head0(x, cond)

View File

@ -92,7 +92,7 @@ def apply_curriculum_wall_mask(
removed_maps = masked_maps.clone()
area = H * W * mask_ratio
l = math.ceil(math.sqrt(area))
l = math.floor(math.sqrt(area))
nx = random.randint(0, W - l)
ny = random.randint(0, H - l)
masked_maps[mask_classes, nx:nx+l, ny:ny+l] = 0
@ -155,7 +155,7 @@ class GinkaWGANDataset(Dataset):
def handle_stage3(self, target, tag_cond, val_cond):
# 第三阶段,联合生成,输入随机蒙版
removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
rand = torch.rand(32, 32, 32, device=target.device)
@ -164,7 +164,7 @@ class GinkaWGANDataset(Dataset):
"rand": rand,
"real0": removed1,
"real1": removed1,
"masked1": masked1,
"masked1": removed1,
"real2": removed2,
"masked2": torch.zeros_like(target),
"real3": removed3,

View File

@ -50,9 +50,9 @@ DENSITY_WEIGHTS = [
DENSITY_STAGE = [
[],
[1, 2, 10],
[1, 2, 3, 4, 10],
list(range(0, 11))
[1, 2],
[1, 2, 3, 4],
list(range(0, 10))
]
def get_not_allowed(classes: list[int], include_illegal=False):
@ -232,7 +232,7 @@ def immutable_penalty_loss(
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
# 差异区域(模型试图改变的地方)
penalty = F.cross_entropy(input_mask, target_mask)
penalty = torch.clamp(F.cross_entropy(input_mask, target_mask) - 0.2, min=0)
return penalty
@ -314,7 +314,7 @@ class WGANGinkaLoss:
probs_fake = F.softmax(fake, dim=1)
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake)
@ -342,7 +342,7 @@ class WGANGinkaLoss:
probs_fake = F.softmax(fake, dim=1)
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
minamo_loss = -torch.mean(fake_scores)
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
@ -368,7 +368,7 @@ class WGANGinkaLoss:
probs_fake = F.softmax(fake, dim=1)
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
@ -397,7 +397,7 @@ class WGANGinkaLoss:
losses = [
head_scores,
input_head_illegal_loss(probs) * 50,
-js_divergence(probs_a, probs_b, softmax=False) * 0.1
-js_divergence(probs_a, probs_b, softmax=False) * 0.5
]
return sum(losses)

View File

@ -108,7 +108,7 @@ def train():
g_steps = 1
# 训练阶段
train_stage = 1
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
mask_ratio = 0.2 # 蒙版区域大小
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
total_epoch = 0
@ -123,8 +123,8 @@ def train():
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1)
criterion = WGANGinkaLoss()
@ -171,6 +171,7 @@ def train():
if args.tuning:
train_stage = 1
curr_epoch = curr_epoch // 4
first_curr = first_curr // 4
stage_epoch = 0
mask_ratio = 0.2
@ -183,8 +184,6 @@ def train():
dataset_val.mask_ratio1 = mask_ratio
dataset_val.mask_ratio2 = mask_ratio
dataset_val.mask_ratio3 = mask_ratio
low_loss_epochs = 0
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
loss_total_minamo = torch.Tensor([0]).to(device)
@ -265,7 +264,7 @@ def train():
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0
if train_stage < 4:
fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1)
@ -282,9 +281,10 @@ def train():
avg_dis = dis_total.item() / len(dataloader) / c_steps
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
)
# 每若干轮输出一次图片,并保存检查点
@ -364,46 +364,34 @@ def train():
# 训练流程控制
if train_stage >= 2:
if (epoch + 1) % 10 == 1:
# train_stage = 4
if (epoch + 1) % 100 == 5:
train_stage = 3
elif (epoch + 1) % 10 == 3:
elif (epoch + 1) % 100 == 20:
train_stage = 4
elif (epoch + 1) % 10 == 0:
elif (epoch + 1) % 100 == 0:
train_stage = 2
if train_stage == 1:
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
(mask_ratio > 0.3 and stage_epoch >= curr_epoch):
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
mask_ratio = min(mask_ratio, 0.8)
stage_epoch = 0
if mask_ratio >= 0.8:
train_stage = 2
stage_epoch += 1
total_epoch += 1
# scheduler_ginka.step()
# scheduler_minamo.step()
# if avg_dis < 0:
# g_steps = max(int(-avg_dis * 5), 1)
# else:
# g_steps = 1
# if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
# g_steps += int(min(avg_loss_ginka * 5, 50))
# if avg_loss_minamo > 0:
# c_steps = int(min(5 + avg_loss_minamo * 5, 15))
# else:
# c_steps = 5
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
scheduler_ginka.step()
scheduler_minamo.step()
print("Train ended.")
torch.save({