mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-18 07:31:11 +08:00
chore: 微调模型
This commit is contained in:
parent
fb0323d874
commit
5586ea1039
@ -230,8 +230,8 @@ class MinamoModel2(nn.Module):
|
||||
super().__init__()
|
||||
self.cond = ConditionEncoder(64, 16, 256, 256)
|
||||
|
||||
self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13)
|
||||
self.conv2 = ConvFusionModule(128, 256, 256, 13, 13)
|
||||
self.conv1 = ConvFusionModule(tile_types, 256, 256, 13, 13)
|
||||
self.conv2 = ConvFusionModule(256, 512, 256, 13, 13)
|
||||
self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)
|
||||
|
||||
self.head0 = MinamoHead2(256, 256) # 随机头的判别头
|
||||
@ -239,9 +239,9 @@ class MinamoModel2(nn.Module):
|
||||
self.head2 = MinamoHead2(256, 256)
|
||||
self.head3 = MinamoHead2(256, 256)
|
||||
|
||||
# self.inject1 = ConditionInjector(256, 128)
|
||||
# self.inject1 = ConditionInjector(256, 256)
|
||||
# self.inject2 = ConditionInjector(256, 256)
|
||||
# self.inject3 = ConditionInjector(256, 256)
|
||||
self.inject3 = ConditionInjector(256, 256)
|
||||
|
||||
def forward(self, x, stage, tag_cond, val_cond):
|
||||
B, D = tag_cond.shape
|
||||
@ -252,7 +252,7 @@ class MinamoModel2(nn.Module):
|
||||
x = self.conv2(x)
|
||||
# x = self.inject2(x, cond)
|
||||
x = self.conv3(x)
|
||||
# x = self.inject3(x, cond)
|
||||
x = self.inject3(x, cond)
|
||||
|
||||
if stage == 0:
|
||||
score = self.head0(x, cond)
|
||||
|
||||
@ -92,7 +92,7 @@ def apply_curriculum_wall_mask(
|
||||
removed_maps = masked_maps.clone()
|
||||
|
||||
area = H * W * mask_ratio
|
||||
l = math.ceil(math.sqrt(area))
|
||||
l = math.floor(math.sqrt(area))
|
||||
nx = random.randint(0, W - l)
|
||||
ny = random.randint(0, H - l)
|
||||
masked_maps[mask_classes, nx:nx+l, ny:ny+l] = 0
|
||||
@ -155,7 +155,7 @@ class GinkaWGANDataset(Dataset):
|
||||
|
||||
def handle_stage3(self, target, tag_cond, val_cond):
|
||||
# 第三阶段,联合生成,输入随机蒙版
|
||||
removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
||||
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
|
||||
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
||||
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
||||
rand = torch.rand(32, 32, 32, device=target.device)
|
||||
@ -164,7 +164,7 @@ class GinkaWGANDataset(Dataset):
|
||||
"rand": rand,
|
||||
"real0": removed1,
|
||||
"real1": removed1,
|
||||
"masked1": masked1,
|
||||
"masked1": removed1,
|
||||
"real2": removed2,
|
||||
"masked2": torch.zeros_like(target),
|
||||
"real3": removed3,
|
||||
|
||||
@ -50,9 +50,9 @@ DENSITY_WEIGHTS = [
|
||||
|
||||
DENSITY_STAGE = [
|
||||
[],
|
||||
[1, 2, 10],
|
||||
[1, 2, 3, 4, 10],
|
||||
list(range(0, 11))
|
||||
[1, 2],
|
||||
[1, 2, 3, 4],
|
||||
list(range(0, 10))
|
||||
]
|
||||
|
||||
def get_not_allowed(classes: list[int], include_illegal=False):
|
||||
@ -232,7 +232,7 @@ def immutable_penalty_loss(
|
||||
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
||||
|
||||
# 差异区域(模型试图改变的地方)
|
||||
penalty = F.cross_entropy(input_mask, target_mask)
|
||||
penalty = torch.clamp(F.cross_entropy(input_mask, target_mask) - 0.2, min=0)
|
||||
|
||||
return penalty
|
||||
|
||||
@ -314,7 +314,7 @@ class WGANGinkaLoss:
|
||||
probs_fake = F.softmax(fake, dim=1)
|
||||
|
||||
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
||||
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
@ -342,7 +342,7 @@ class WGANGinkaLoss:
|
||||
probs_fake = F.softmax(fake, dim=1)
|
||||
|
||||
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
||||
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
||||
@ -368,7 +368,7 @@ class WGANGinkaLoss:
|
||||
probs_fake = F.softmax(fake, dim=1)
|
||||
|
||||
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
||||
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
||||
constraint_loss = inner_constraint_loss(probs_fake)
|
||||
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
||||
@ -397,7 +397,7 @@ class WGANGinkaLoss:
|
||||
losses = [
|
||||
head_scores,
|
||||
input_head_illegal_loss(probs) * 50,
|
||||
-js_divergence(probs_a, probs_b, softmax=False) * 0.1
|
||||
-js_divergence(probs_a, probs_b, softmax=False) * 0.5
|
||||
]
|
||||
|
||||
return sum(losses)
|
||||
|
||||
@ -108,7 +108,7 @@ def train():
|
||||
g_steps = 1
|
||||
# 训练阶段
|
||||
train_stage = 1
|
||||
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||
mask_ratio = 0.2 # 蒙版区域大小
|
||||
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
||||
total_epoch = 0
|
||||
|
||||
@ -123,8 +123,8 @@ def train():
|
||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
|
||||
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
||||
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1)
|
||||
|
||||
criterion = WGANGinkaLoss()
|
||||
|
||||
@ -171,6 +171,7 @@ def train():
|
||||
if args.tuning:
|
||||
train_stage = 1
|
||||
curr_epoch = curr_epoch // 4
|
||||
first_curr = first_curr // 4
|
||||
stage_epoch = 0
|
||||
mask_ratio = 0.2
|
||||
|
||||
@ -183,8 +184,6 @@ def train():
|
||||
dataset_val.mask_ratio1 = mask_ratio
|
||||
dataset_val.mask_ratio2 = mask_ratio
|
||||
dataset_val.mask_ratio3 = mask_ratio
|
||||
|
||||
low_loss_epochs = 0
|
||||
|
||||
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
||||
loss_total_minamo = torch.Tensor([0]).to(device)
|
||||
@ -265,7 +264,7 @@ def train():
|
||||
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
|
||||
|
||||
loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0
|
||||
|
||||
|
||||
if train_stage < 4:
|
||||
fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1)
|
||||
|
||||
@ -282,9 +281,10 @@ def train():
|
||||
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
||||
tqdm.write(
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
|
||||
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||
)
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
@ -364,46 +364,34 @@ def train():
|
||||
# 训练流程控制
|
||||
|
||||
if train_stage >= 2:
|
||||
if (epoch + 1) % 10 == 1:
|
||||
# train_stage = 4
|
||||
if (epoch + 1) % 100 == 5:
|
||||
train_stage = 3
|
||||
elif (epoch + 1) % 10 == 3:
|
||||
elif (epoch + 1) % 100 == 20:
|
||||
train_stage = 4
|
||||
elif (epoch + 1) % 10 == 0:
|
||||
elif (epoch + 1) % 100 == 0:
|
||||
train_stage = 2
|
||||
|
||||
if train_stage == 1:
|
||||
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
|
||||
(mask_ratio > 0.3 and stage_epoch >= curr_epoch):
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
mask_ratio += 0.2
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
mask_ratio = min(mask_ratio, 0.8)
|
||||
|
||||
stage_epoch = 0
|
||||
if mask_ratio >= 0.8:
|
||||
train_stage = 2
|
||||
|
||||
stage_epoch += 1
|
||||
total_epoch += 1
|
||||
|
||||
# scheduler_ginka.step()
|
||||
# scheduler_minamo.step()
|
||||
|
||||
# if avg_dis < 0:
|
||||
# g_steps = max(int(-avg_dis * 5), 1)
|
||||
# else:
|
||||
# g_steps = 1
|
||||
|
||||
# if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
|
||||
# g_steps += int(min(avg_loss_ginka * 5, 50))
|
||||
|
||||
# if avg_loss_minamo > 0:
|
||||
# c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||
# else:
|
||||
# c_steps = 5
|
||||
|
||||
dataset.train_stage = train_stage
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
||||
|
||||
scheduler_ginka.step()
|
||||
scheduler_minamo.step()
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user