mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 改进训练流程
This commit is contained in:
parent
87016c67e8
commit
a94b07bda8
@ -106,11 +106,9 @@ class GinkaWGANDataset(Dataset):
|
||||
|
||||
def handle_stage3(self, target):
    """Build the stage-3 sample tuple (joint generation, randomly masked input).

    Args:
        target: The target map tensor passed to the curriculum helpers.

    Returns:
        A 6-tuple ``(removed1, masked1, removed2, zeros, removed3, zeros)``
        where each ``zeros`` is a fresh ``torch.zeros_like(target)``.
    """
    # Stage 3: joint generation, the input is a random mask.
    # The two random draws are kept in their original order so that a seeded
    # run consumes the RNG stream identically.
    smooth_eps = random.uniform(0, self.random_ratio)
    mask_fraction = random.uniform(0.1, 0.9)

    removed1, stage1_input = apply_curriculum_mask(
        target, STAGE1_MASK, STAGE1_REMOVE, mask_fraction
    )
    removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
    removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)

    # Smooth the masked stage-1 input; the main value is bounded below by
    # 1 - smooth_eps and the same value is passed as epsilon.
    stage1_input = random_smooth_onehot(
        stage1_input, min_main=1 - smooth_eps, max_main=1.0, epsilon=smooth_eps
    )

    # Two separate zero tensors (not one shared object) fill the stage-2 and
    # stage-3 masked-input slots.
    stage2_placeholder = torch.zeros_like(target)
    stage3_placeholder = torch.zeros_like(target)
    return (
        removed1,
        stage1_input,
        removed2,
        stage2_placeholder,
        removed3,
        stage3_placeholder,
    )
|
||||
|
||||
def handle_stage4(self, target):
|
||||
|
||||
@ -2,10 +2,10 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class RandomInputHead(nn.Module):
|
||||
def __init__(self, in_size=(32, 32), out_size=(32, 32)):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(1, 32, 3, padding=1, padding_mode='replicate'),
|
||||
nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(32),
|
||||
nn.ELU(),
|
||||
|
||||
@ -18,6 +18,7 @@ class RandomInputHead(nn.Module):
|
||||
nn.ELU(),
|
||||
)
|
||||
self.out_conv = nn.Sequential(
|
||||
nn.AdaptiveMaxPool2d((13, 13)),
|
||||
nn.Conv2d(128, 32, 1),
|
||||
)
|
||||
|
||||
|
||||
@ -419,18 +419,18 @@ class WGANGinkaLoss:
|
||||
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
# fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
ce_loss * self.weight[1], # 蒙版越大,交叉熵损失权重越小
|
||||
ce_loss * self.weight[1],
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -450,12 +450,12 @@ class WGANGinkaLoss:
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
# fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
constraint_loss * self.weight[3],
|
||||
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -474,13 +474,13 @@ class WGANGinkaLoss:
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
# fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
|
||||
@ -3,7 +3,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from .unet import GinkaUNet
|
||||
from .output import GinkaOutput
|
||||
from .input import GinkaInput
|
||||
from .input import GinkaInput, RandomInputHead
|
||||
|
||||
def print_memory(tag=""):
|
||||
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
||||
@ -13,21 +13,20 @@ class GinkaModel(nn.Module):
|
||||
"""Ginka Model 模型定义部分
|
||||
"""
|
||||
super().__init__()
|
||||
self.head = RandomInputHead()
|
||||
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
|
||||
self.unet = GinkaUNet(32, base_ch, base_ch)
|
||||
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
||||
|
||||
def forward(self, x, stage, random=False):
    """Run the generator for one stage.

    Args:
        x: Feature tensor of the reference map. When ``random`` is true it is
            first passed through ``self.head`` instead of being used directly.
        stage: Stage identifier forwarded to ``self.output``.
        random: If true, derive the network input from ``self.head(x)``
            followed by a softmax over the channel dimension.

    Returns:
        A tuple ``(logits, x_in)``: the output of ``self.output`` (per the
        original docstring, logits shaped [BS, num_classes, H, W]) and the
        tensor that was actually fed to ``self.input``.
    """
    if random:
        # Explicit dim=1: the implicit-dim form of softmax is deprecated and
        # warns; dim=1 matches the previous external
        # `F.softmax(ginka_head(masked1), dim=1)` call this replaces.
        x_in = F.softmax(self.head(x), dim=1)
    else:
        x_in = x
    x = self.input(x_in)
    x = self.unet(x)
    x = self.output(x, stage)
    return x, x_in
|
||||
|
||||
# 检查显存占用
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -33,31 +33,32 @@ def parse_arguments():
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--checkpoint", type=int, default=5)
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
parser.add_argument("--curr_epoch", type=int, default=20) # 课程学习至少多少 epoch
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the generator independently on the three curriculum inputs.

    Args:
        gen: Callable generator; ``gen(x, stage)`` returns ``(output, x_in)``.
        masked1: Input for stage 1.
        masked2: Input for stage 2.
        masked3: Input for stage 3.
        detach: If true, return the outputs detached from the autograd graph.

    Returns:
        ``(fake1, fake2, fake3)``, the first element of each generator result.
    """
    # The generator returns (output, x_in); only the output is needed here.
    # Each stage is evaluated exactly once.
    fake1, _ = gen(masked1, 1)
    fake2, _ = gen(masked2, 2)
    fake3, _ = gen(masked3, 3)
    if detach:
        return fake1.detach(), fake2.detach(), fake3.detach()
    return fake1, fake2, fake3
|
||||
|
||||
def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the three generator stages as a chain, each feeding the next.

    Stage 1 receives ``input``; stages 2 and 3 receive the channel-wise
    softmax of the previous stage's output.

    Args:
        gen: Callable generator; ``gen(x, stage[, random])`` returns
            ``(output, x_in)``.
        input: Tensor fed to stage 1.
        progress_detach: If true, every tensor handed to a stage is detached
            first, so gradients do not flow between stages.
        result_detach: If true, all returned tensors are detached.
        random: Forwarded to the stage-1 generator call only.

    Returns:
        ``(fake1, fake2, fake3, x_in)``: the three stage outputs and the
        second element of the stage-1 generator result.
    """

    def _chain(logits):
        # Turn the previous stage's output into the next stage's input.
        # dim=1 is given explicitly: calling softmax without `dim` is
        # deprecated and emits a warning on every call.
        if progress_detach:
            logits = logits.detach()
        return F.softmax(logits, dim=1)

    first_input = input.detach() if progress_detach else input
    fake1, x_in = gen(first_input, 1, random)
    fake2, _ = gen(_chain(fake1), 2)
    fake3, _ = gen(_chain(fake2), 3)

    if result_detach:
        return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
    return fake1, fake2, fake3, x_in
|
||||
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
@ -169,14 +170,9 @@ def train():
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||
|
||||
elif train_stage == 3:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
|
||||
|
||||
elif train_stage == 4:
|
||||
input = F.softmax(ginka_head(masked1), dim=1)
|
||||
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4)
|
||||
|
||||
|
||||
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
||||
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
||||
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
||||
@ -214,9 +210,7 @@ def train():
|
||||
loss_ce_total += loss_ce.detach()
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
|
||||
|
||||
fake1, fake2, fake3 = gen_total(ginka, input, True, False)
|
||||
fake1, fake2, fake3, x_in = gen_total(ginka, input, True, False)
|
||||
|
||||
if train_stage == 3:
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
|
||||
@ -225,6 +219,10 @@ def train():
|
||||
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
|
||||
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
|
||||
|
||||
if train_stage == 4:
|
||||
loss_head = criterion.generator_input_head_loss(x_in)
|
||||
loss_head.backward()
|
||||
|
||||
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||
loss_g.backward()
|
||||
optimizer_ginka.step()
|
||||
@ -246,24 +244,30 @@ def train():
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
if low_loss_epochs >= 3 and train_stage == 1:
|
||||
# 训练流程控制
|
||||
|
||||
if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch:
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
stage_epoch = 0
|
||||
mask_ratio += 0.2
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
stage_epoch = 0
|
||||
|
||||
if train_stage == 3 or train_stage == 2:
|
||||
if (train_stage == 3 or train_stage == 2) and not last_stage:
|
||||
if stage_epoch >= 25:
|
||||
train_stage += 1
|
||||
stage_epoch = 0
|
||||
|
||||
if train_stage >= 3:
|
||||
if train_stage == 4:
|
||||
last_stage = True
|
||||
|
||||
if train_stage >= 3 or last_stage:
|
||||
# 第三阶段后交叉熵损失不再应该生效
|
||||
mask_ratio = 1.0
|
||||
|
||||
if last_stage:
|
||||
mask_ratio = 1.0
|
||||
if train_stage == 2 and stage_epoch % 5 == 0:
|
||||
train_stage = 4
|
||||
|
||||
@ -317,7 +321,7 @@ def train():
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
|
||||
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
|
||||
fake1, fake2, fake3, _ = gen_total(ginka, input, True, True)
|
||||
|
||||
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
||||
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user