mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 12:21:11 +08:00
chore: 微调模型
This commit is contained in:
parent
fb0323d874
commit
5586ea1039
@ -230,8 +230,8 @@ class MinamoModel2(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.cond = ConditionEncoder(64, 16, 256, 256)
|
self.cond = ConditionEncoder(64, 16, 256, 256)
|
||||||
|
|
||||||
self.conv1 = ConvFusionModule(tile_types, 256, 128, 13, 13)
|
self.conv1 = ConvFusionModule(tile_types, 256, 256, 13, 13)
|
||||||
self.conv2 = ConvFusionModule(128, 256, 256, 13, 13)
|
self.conv2 = ConvFusionModule(256, 512, 256, 13, 13)
|
||||||
self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)
|
self.conv3 = ConvFusionModule(256, 512, 256, 13, 13)
|
||||||
|
|
||||||
self.head0 = MinamoHead2(256, 256) # 随机头的判别头
|
self.head0 = MinamoHead2(256, 256) # 随机头的判别头
|
||||||
@ -239,9 +239,9 @@ class MinamoModel2(nn.Module):
|
|||||||
self.head2 = MinamoHead2(256, 256)
|
self.head2 = MinamoHead2(256, 256)
|
||||||
self.head3 = MinamoHead2(256, 256)
|
self.head3 = MinamoHead2(256, 256)
|
||||||
|
|
||||||
# self.inject1 = ConditionInjector(256, 128)
|
# self.inject1 = ConditionInjector(256, 256)
|
||||||
# self.inject2 = ConditionInjector(256, 256)
|
# self.inject2 = ConditionInjector(256, 256)
|
||||||
# self.inject3 = ConditionInjector(256, 256)
|
self.inject3 = ConditionInjector(256, 256)
|
||||||
|
|
||||||
def forward(self, x, stage, tag_cond, val_cond):
|
def forward(self, x, stage, tag_cond, val_cond):
|
||||||
B, D = tag_cond.shape
|
B, D = tag_cond.shape
|
||||||
@ -252,7 +252,7 @@ class MinamoModel2(nn.Module):
|
|||||||
x = self.conv2(x)
|
x = self.conv2(x)
|
||||||
# x = self.inject2(x, cond)
|
# x = self.inject2(x, cond)
|
||||||
x = self.conv3(x)
|
x = self.conv3(x)
|
||||||
# x = self.inject3(x, cond)
|
x = self.inject3(x, cond)
|
||||||
|
|
||||||
if stage == 0:
|
if stage == 0:
|
||||||
score = self.head0(x, cond)
|
score = self.head0(x, cond)
|
||||||
|
|||||||
@ -92,7 +92,7 @@ def apply_curriculum_wall_mask(
|
|||||||
removed_maps = masked_maps.clone()
|
removed_maps = masked_maps.clone()
|
||||||
|
|
||||||
area = H * W * mask_ratio
|
area = H * W * mask_ratio
|
||||||
l = math.ceil(math.sqrt(area))
|
l = math.floor(math.sqrt(area))
|
||||||
nx = random.randint(0, W - l)
|
nx = random.randint(0, W - l)
|
||||||
ny = random.randint(0, H - l)
|
ny = random.randint(0, H - l)
|
||||||
masked_maps[mask_classes, nx:nx+l, ny:ny+l] = 0
|
masked_maps[mask_classes, nx:nx+l, ny:ny+l] = 0
|
||||||
@ -155,7 +155,7 @@ class GinkaWGANDataset(Dataset):
|
|||||||
|
|
||||||
def handle_stage3(self, target, tag_cond, val_cond):
|
def handle_stage3(self, target, tag_cond, val_cond):
|
||||||
# 第三阶段,联合生成,输入随机蒙版
|
# 第三阶段,联合生成,输入随机蒙版
|
||||||
removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
|
||||||
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
||||||
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
||||||
rand = torch.rand(32, 32, 32, device=target.device)
|
rand = torch.rand(32, 32, 32, device=target.device)
|
||||||
@ -164,7 +164,7 @@ class GinkaWGANDataset(Dataset):
|
|||||||
"rand": rand,
|
"rand": rand,
|
||||||
"real0": removed1,
|
"real0": removed1,
|
||||||
"real1": removed1,
|
"real1": removed1,
|
||||||
"masked1": masked1,
|
"masked1": removed1,
|
||||||
"real2": removed2,
|
"real2": removed2,
|
||||||
"masked2": torch.zeros_like(target),
|
"masked2": torch.zeros_like(target),
|
||||||
"real3": removed3,
|
"real3": removed3,
|
||||||
|
|||||||
@ -50,9 +50,9 @@ DENSITY_WEIGHTS = [
|
|||||||
|
|
||||||
DENSITY_STAGE = [
|
DENSITY_STAGE = [
|
||||||
[],
|
[],
|
||||||
[1, 2, 10],
|
[1, 2],
|
||||||
[1, 2, 3, 4, 10],
|
[1, 2, 3, 4],
|
||||||
list(range(0, 11))
|
list(range(0, 10))
|
||||||
]
|
]
|
||||||
|
|
||||||
def get_not_allowed(classes: list[int], include_illegal=False):
|
def get_not_allowed(classes: list[int], include_illegal=False):
|
||||||
@ -232,7 +232,7 @@ def immutable_penalty_loss(
|
|||||||
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
||||||
|
|
||||||
# 差异区域(模型试图改变的地方)
|
# 差异区域(模型试图改变的地方)
|
||||||
penalty = F.cross_entropy(input_mask, target_mask)
|
penalty = torch.clamp(F.cross_entropy(input_mask, target_mask) - 0.2, min=0)
|
||||||
|
|
||||||
return penalty
|
return penalty
|
||||||
|
|
||||||
@ -314,7 +314,7 @@ class WGANGinkaLoss:
|
|||||||
probs_fake = F.softmax(fake, dim=1)
|
probs_fake = F.softmax(fake, dim=1)
|
||||||
|
|
||||||
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
||||||
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
||||||
constraint_loss = inner_constraint_loss(probs_fake)
|
constraint_loss = inner_constraint_loss(probs_fake)
|
||||||
@ -342,7 +342,7 @@ class WGANGinkaLoss:
|
|||||||
probs_fake = F.softmax(fake, dim=1)
|
probs_fake = F.softmax(fake, dim=1)
|
||||||
|
|
||||||
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
||||||
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
|
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
|
||||||
constraint_loss = inner_constraint_loss(probs_fake)
|
constraint_loss = inner_constraint_loss(probs_fake)
|
||||||
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
||||||
@ -368,7 +368,7 @@ class WGANGinkaLoss:
|
|||||||
probs_fake = F.softmax(fake, dim=1)
|
probs_fake = F.softmax(fake, dim=1)
|
||||||
|
|
||||||
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
|
||||||
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
immutable_loss = immutable_penalty_loss(fake, input, STAGE_CHANGEABLE[stage])
|
||||||
constraint_loss = inner_constraint_loss(probs_fake)
|
constraint_loss = inner_constraint_loss(probs_fake)
|
||||||
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
|
||||||
@ -397,7 +397,7 @@ class WGANGinkaLoss:
|
|||||||
losses = [
|
losses = [
|
||||||
head_scores,
|
head_scores,
|
||||||
input_head_illegal_loss(probs) * 50,
|
input_head_illegal_loss(probs) * 50,
|
||||||
-js_divergence(probs_a, probs_b, softmax=False) * 0.1
|
-js_divergence(probs_a, probs_b, softmax=False) * 0.5
|
||||||
]
|
]
|
||||||
|
|
||||||
return sum(losses)
|
return sum(losses)
|
||||||
|
|||||||
@ -108,7 +108,7 @@ def train():
|
|||||||
g_steps = 1
|
g_steps = 1
|
||||||
# 训练阶段
|
# 训练阶段
|
||||||
train_stage = 1
|
train_stage = 1
|
||||||
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
mask_ratio = 0.2 # 蒙版区域大小
|
||||||
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
||||||
total_epoch = 0
|
total_epoch = 0
|
||||||
|
|
||||||
@ -123,8 +123,8 @@ def train():
|
|||||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||||
|
|
||||||
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=100, T_mult=1)
|
||||||
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
|
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=100, T_mult=1)
|
||||||
|
|
||||||
criterion = WGANGinkaLoss()
|
criterion = WGANGinkaLoss()
|
||||||
|
|
||||||
@ -171,6 +171,7 @@ def train():
|
|||||||
if args.tuning:
|
if args.tuning:
|
||||||
train_stage = 1
|
train_stage = 1
|
||||||
curr_epoch = curr_epoch // 4
|
curr_epoch = curr_epoch // 4
|
||||||
|
first_curr = first_curr // 4
|
||||||
stage_epoch = 0
|
stage_epoch = 0
|
||||||
mask_ratio = 0.2
|
mask_ratio = 0.2
|
||||||
|
|
||||||
@ -184,8 +185,6 @@ def train():
|
|||||||
dataset_val.mask_ratio2 = mask_ratio
|
dataset_val.mask_ratio2 = mask_ratio
|
||||||
dataset_val.mask_ratio3 = mask_ratio
|
dataset_val.mask_ratio3 = mask_ratio
|
||||||
|
|
||||||
low_loss_epochs = 0
|
|
||||||
|
|
||||||
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
||||||
loss_total_minamo = torch.Tensor([0]).to(device)
|
loss_total_minamo = torch.Tensor([0]).to(device)
|
||||||
loss_total_ginka = torch.Tensor([0]).to(device)
|
loss_total_ginka = torch.Tensor([0]).to(device)
|
||||||
@ -282,9 +281,10 @@ def train():
|
|||||||
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||||
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
f"E: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||||
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | " +
|
||||||
|
f"LR: {optimizer_ginka.param_groups[0]['lr']:.6f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 每若干轮输出一次图片,并保存检查点
|
# 每若干轮输出一次图片,并保存检查点
|
||||||
@ -364,47 +364,35 @@ def train():
|
|||||||
# 训练流程控制
|
# 训练流程控制
|
||||||
|
|
||||||
if train_stage >= 2:
|
if train_stage >= 2:
|
||||||
if (epoch + 1) % 10 == 1:
|
# train_stage = 4
|
||||||
|
if (epoch + 1) % 100 == 5:
|
||||||
train_stage = 3
|
train_stage = 3
|
||||||
elif (epoch + 1) % 10 == 3:
|
elif (epoch + 1) % 100 == 20:
|
||||||
train_stage = 4
|
train_stage = 4
|
||||||
elif (epoch + 1) % 10 == 0:
|
elif (epoch + 1) % 100 == 0:
|
||||||
train_stage = 2
|
train_stage = 2
|
||||||
|
|
||||||
if train_stage == 1:
|
if train_stage == 1:
|
||||||
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
|
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
|
||||||
(mask_ratio > 0.3 and stage_epoch >= curr_epoch):
|
(mask_ratio > 0.3 and stage_epoch >= curr_epoch):
|
||||||
if mask_ratio >= 0.9:
|
|
||||||
train_stage = 2
|
|
||||||
mask_ratio += 0.2
|
mask_ratio += 0.2
|
||||||
mask_ratio = min(mask_ratio, 0.9)
|
mask_ratio = min(mask_ratio, 0.8)
|
||||||
low_loss_epochs = 0
|
|
||||||
stage_epoch = 0
|
stage_epoch = 0
|
||||||
|
if mask_ratio >= 0.8:
|
||||||
|
train_stage = 2
|
||||||
|
|
||||||
stage_epoch += 1
|
stage_epoch += 1
|
||||||
total_epoch += 1
|
total_epoch += 1
|
||||||
|
|
||||||
# scheduler_ginka.step()
|
|
||||||
# scheduler_minamo.step()
|
|
||||||
|
|
||||||
# if avg_dis < 0:
|
|
||||||
# g_steps = max(int(-avg_dis * 5), 1)
|
|
||||||
# else:
|
|
||||||
# g_steps = 1
|
|
||||||
|
|
||||||
# if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
|
|
||||||
# g_steps += int(min(avg_loss_ginka * 5, 50))
|
|
||||||
|
|
||||||
# if avg_loss_minamo > 0:
|
|
||||||
# c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
|
||||||
# else:
|
|
||||||
# c_steps = 5
|
|
||||||
|
|
||||||
dataset.train_stage = train_stage
|
dataset.train_stage = train_stage
|
||||||
dataset_val.train_stage = train_stage
|
dataset_val.train_stage = train_stage
|
||||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||||
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
||||||
|
|
||||||
|
scheduler_ginka.step()
|
||||||
|
scheduler_minamo.step()
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": ginka.state_dict(),
|
"model_state": ginka.state_dict(),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user