feat: 改进训练流程

This commit is contained in:
unanmed 2025-04-20 22:06:20 +08:00
parent 87016c67e8
commit a94b07bda8
5 changed files with 51 additions and 49 deletions

View File

@@ -106,11 +106,9 @@ class GinkaWGANDataset(Dataset):
def handle_stage3(self, target):
# 第三阶段,联合生成,输入随机蒙版
rd = random.uniform(0, self.random_ratio)
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
def handle_stage4(self, target):

View File

@@ -2,10 +2,10 @@ import torch
import torch.nn as nn
class RandomInputHead(nn.Module):
def __init__(self, in_size=(32, 32), out_size=(32, 32)):
def __init__(self):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(1, 32, 3, padding=1, padding_mode='replicate'),
nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(32),
nn.ELU(),
@@ -18,6 +18,7 @@ class RandomInputHead(nn.Module):
nn.ELU(),
)
self.out_conv = nn.Sequential(
nn.AdaptiveMaxPool2d((13, 13)),
nn.Conv2d(128, 32, 1),
)

View File

@@ -419,18 +419,18 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
# fake_a, fake_b = fake.chunk(2, dim=0)
fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
ce_loss * self.weight[1], # 蒙版越大,交叉熵损失权重越小
ce_loss * self.weight[1],
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
@@ -450,12 +450,12 @@ class WGANGinkaLoss:
minamo_loss = -torch.mean(fake_scores)
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
# fake_a, fake_b = fake.chunk(2, dim=0)
fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
constraint_loss * self.weight[3],
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
@@ -474,13 +474,13 @@ class WGANGinkaLoss:
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
# fake_a, fake_b = fake.chunk(2, dim=0)
fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:

View File

@@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .unet import GinkaUNet
from .output import GinkaOutput
from .input import GinkaInput
from .input import GinkaInput, RandomInputHead
def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -13,21 +13,20 @@ class GinkaModel(nn.Module):
"""Ginka Model 模型定义部分
"""
super().__init__()
self.head = RandomInputHead()
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
self.unet = GinkaUNet(32, base_ch, base_ch)
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x, stage):
"""
Args:
x: 参考地图的特征向量
Returns:
logits: 输出logits [BS, num_classes, H, W]
"""
x = self.input(x)
def forward(self, x, stage, random=False):
if random:
x_in = F.softmax(self.head(x))
else:
x_in = x
x = self.input(x_in)
x = self.unet(x)
x = self.output(x, stage)
return x
return x, x_in
# 检查显存占用
if __name__ == "__main__":

View File

@@ -33,31 +33,32 @@ def parse_arguments():
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--checkpoint", type=int, default=5)
parser.add_argument("--load_optim", type=bool, default=True)
parser.add_argument("--curr_epoch", type=int, default=20) # 课程学习至少多少 epoch
args = parser.parse_args()
return args
def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
fake1: torch.Tensor = gen(masked1, 1)
fake2: torch.Tensor = gen(masked2, 2)
fake3: torch.Tensor = gen(masked3, 3)
fake1, _ = gen(masked1, 1)
fake2, _ = gen(masked2, 2)
fake3, _ = gen(masked3, 3)
if detach:
return fake1.detach(), fake2.detach(), fake3.detach()
else:
return fake1, fake2, fake3
def gen_total(gen, input, progress_detach=True, result_detach=False) -> torch.Tensor:
def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
if progress_detach:
fake1 = gen(input.detach(), 1)
fake2 = gen(fake1.detach(), 2)
fake3 = gen(fake2.detach(), 3)
fake1, x_in = gen(input.detach(), 1, random)
fake2, _ = gen(F.softmax(fake1.detach()), 2)
fake3, _ = gen(F.softmax(fake2.detach()), 3)
else:
fake1 = gen(input, 1)
fake2 = gen(fake1, 2)
fake3 = gen(fake2, 3)
fake1, x_in = gen(input, 1, random)
fake2, _ = gen(F.softmax(fake1), 2)
fake3, _ = gen(F.softmax(fake2), 3)
if result_detach:
return fake1.detach(), fake2.detach(), fake3.detach()
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
else:
return fake1, fake2, fake3
return fake1, fake2, fake3, x_in
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
@@ -169,13 +170,8 @@ def train():
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
elif train_stage == 3:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
elif train_stage == 4:
input = F.softmax(ginka_head(masked1), dim=1)
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, _ = gen_total(ginka, masked1, True, True, train_stage == 4)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
@@ -214,9 +210,7 @@ def train():
loss_ce_total += loss_ce.detach()
elif train_stage == 3 or train_stage == 4:
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
fake1, fake2, fake3 = gen_total(ginka, input, True, False)
fake1, fake2, fake3, x_in = gen_total(ginka, input, True, False)
if train_stage == 3:
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
@@ -225,6 +219,10 @@ def train():
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2)
if train_stage == 4:
loss_head = criterion.generator_input_head_loss(x_in)
loss_head.backward()
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_g.backward()
optimizer_ginka.step()
@@ -246,24 +244,30 @@ def train():
else:
low_loss_epochs = 0
if low_loss_epochs >= 3 and train_stage == 1:
# 训练流程控制
if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
stage_epoch = 0
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
stage_epoch = 0
if train_stage == 3 or train_stage == 2:
if (train_stage == 3 or train_stage == 2) and not last_stage:
if stage_epoch >= 25:
train_stage += 1
stage_epoch = 0
if train_stage >= 3:
if train_stage == 4:
last_stage = True
if train_stage >= 3 or last_stage:
# 第三阶段后交叉熵损失不再应该生效
mask_ratio = 1.0
if last_stage:
mask_ratio = 1.0
if train_stage == 2 and stage_epoch % 5 == 0:
train_stage = 4
@@ -317,7 +321,7 @@ def train():
elif train_stage == 3 or train_stage == 4:
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
fake1, fake2, fake3, _ = gen_total(ginka, input, True, True)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()