mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-14 04:41:12 +08:00)

perf: improve the generation pipeline and loss computation

This commit is contained in:
parent fa48863946
commit fb0323d874
@@ -239,20 +239,20 @@ class MinamoModel2(nn.Module):
         self.head2 = MinamoHead2(256, 256)
         self.head3 = MinamoHead2(256, 256)

-        self.inject1 = ConditionInjector(256, 128)
-        self.inject2 = ConditionInjector(256, 256)
-        self.inject3 = ConditionInjector(256, 256)
+        # self.inject1 = ConditionInjector(256, 128)
+        # self.inject2 = ConditionInjector(256, 256)
+        # self.inject3 = ConditionInjector(256, 256)

     def forward(self, x, stage, tag_cond, val_cond):
         B, D = tag_cond.shape
         stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
         cond = self.cond(tag_cond, val_cond, stage_tensor)
         x = self.conv1(x)
-        x = self.inject1(x, cond)
+        # x = self.inject1(x, cond)
         x = self.conv2(x)
-        x = self.inject2(x, cond)
+        # x = self.inject2(x, cond)
         x = self.conv3(x)
-        x = self.inject3(x, cond)
+        # x = self.inject3(x, cond)

         if stage == 0:
             score = self.head0(x, cond)
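Review note: this hunk comments out the trunk-level ConditionInjector calls, so conditioning now reaches the critic only through its heads. ConditionInjector's implementation is not part of this diff; a minimal FiLM-style sketch of what such a module commonly looks like (all names and sizes below are assumptions, not the repo's actual code):

import torch
import torch.nn as nn

class FiLMInjector(nn.Module):
    """Hypothetical FiLM-style condition injector: cond -> per-channel scale/shift."""
    def __init__(self, cond_dim: int, feat_ch: int):
        super().__init__()
        self.to_scale = nn.Linear(cond_dim, feat_ch)
        self.to_shift = nn.Linear(cond_dim, feat_ch)

    def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
        # x: [B, C, H, W], cond: [B, cond_dim]
        scale = self.to_scale(cond).unsqueeze(-1).unsqueeze(-1)  # [B, C, 1, 1]
        shift = self.to_shift(cond).unsqueeze(-1).unsqueeze(-1)
        return x * (1 + scale) + shift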
@@ -1,4 +1,5 @@
 import json
+import math
 import random
 import torch
 import torch.nn.functional as F
@@ -72,7 +73,33 @@ def apply_curriculum_mask(
     masked_maps[0, selected[:, 0], selected[:, 1]] = 1  # Set to "empty floor"

     return removed_maps, masked_maps

+def apply_curriculum_wall_mask(
+    maps: torch.Tensor,  # [C, H, W]
+    mask_classes: List[int],  # Class indices to mask
+    remove_classes: List[int],  # Class indices to remove
+    mask_ratio: float  # Mask ratio, 0~1
+) -> torch.Tensor:
+    C, H, W = maps.shape
+    masked_maps = maps.clone()
+
+    # Step 1: remove unwanted classes (set them all to class 0)
+    if remove_classes:
+        remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
+        masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
+        masked_maps[0][remove_mask[0, :, :]] = 1  # Set to "empty floor"
+
+    removed_maps = masked_maps.clone()
+
+    area = H * W * mask_ratio
+    l = math.ceil(math.sqrt(area))
+    nx = random.randint(0, W - l)
+    ny = random.randint(0, H - l)
+    masked_maps[mask_classes, nx:nx+l, ny:ny+l] = 0
+    masked_maps[0, nx:nx+l, ny:ny+l] = 1
+
+    return removed_maps, masked_maps

 class GinkaWGANDataset(Dataset):
     def __init__(self, data_path: str, device):
         self.data = load_data(data_path)  # Custom data-loading function
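Review note: unlike apply_curriculum_mask's scattered cells, the new apply_curriculum_wall_mask hides a single square window whose side length is derived from the target area H*W*mask_ratio. A standalone sketch of just the placement math, on a toy tensor (assumed 32x32 maps as elsewhere in this diff):

import math
import random
import torch

C, H, W = 32, 32, 32
mask_ratio = 0.25
maps = torch.zeros(C, H, W)

area = H * W * mask_ratio           # number of cells to hide
side = math.ceil(math.sqrt(area))   # side of the square window
nx = random.randint(0, W - side)    # top-left corner, kept in bounds
ny = random.randint(0, H - side)
maps[0, nx:nx + side, ny:ny + side] = 1  # mark the window on the "empty floor" plane
print(side * side / (H * W))        # actual hidden fraction; ceil may push it slightly above mask_ratio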
@@ -87,11 +114,14 @@ class GinkaWGANDataset(Dataset):

     def handle_stage1(self, target, tag_cond, val_cond):
         # Curriculum learning stage 1: mask infilling
-        removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
+        removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
         removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
         removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
+        rand = torch.rand(32, 32, 32, device=target.device)

         return {
+            "rand": rand,
+            "real0": removed1,
             "real1": removed1,
             "masked1": masked1,
             "real2": removed2,
@@ -104,12 +134,15 @@ class GinkaWGANDataset(Dataset):

     def handle_stage2(self, target, tag_cond, val_cond):
         # Curriculum learning stage 2: fully random masks
-        removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
+        removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
         # The latter two stages keep some classes, so fully random masking is enough
         removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
         removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
+        rand = torch.rand(32, 32, 32, device=target.device)

         return {
+            "rand": rand,
+            "real0": removed1,
             "real1": removed1,
             "masked1": masked1,
             "real2": removed2,
@@ -122,11 +155,14 @@ class GinkaWGANDataset(Dataset):

     def handle_stage3(self, target, tag_cond, val_cond):
         # Stage 3: joint generation, with randomly masked input
-        removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
+        removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
         removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
         removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
+        rand = torch.rand(32, 32, 32, device=target.device)

         return {
+            "rand": rand,
+            "real0": removed1,
             "real1": removed1,
             "masked1": masked1,
             "real2": removed2,
@@ -142,14 +178,15 @@ class GinkaWGANDataset(Dataset):
         removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
         removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
         removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
-        _, masked = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, 0.5)
         rand = torch.rand(32, 32, 32, device=target.device)

         return {
+            "rand": rand,
+            "real0": removed1,
             "real1": removed1,
             "masked1": rand,
             "real2": removed2,
-            "masked2": masked,
+            "masked2": torch.zeros_like(target),
             "real3": removed3,
             "masked3": torch.zeros_like(target),
             "tag_cond": tag_cond,
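Review note: all four stage handlers now return the same dict keys (rand, real0..real3, masked1..masked3, conds), which is what lets the training loop read batch["rand"] and batch["real0"] unconditionally. The dispatch itself is not shown in this diff; a hypothetical sketch of how __getitem__ presumably selects a handler from the dataset's train_stage attribute (self.data's per-item structure is assumed):

# Hypothetical dispatch, mirroring the handler signatures above.
def __getitem__(self, idx):
    target, tag_cond, val_cond = self.data[idx]
    handlers = {
        1: self.handle_stage1,
        2: self.handle_stage2,
        3: self.handle_stage3,
        4: self.handle_stage4,  # assumed name for the last handler shown
    }
    return handlers[self.train_stage](target, tag_cond, val_cond)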
@@ -155,7 +155,7 @@ def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]):
     C = input_map.shape[1]
     unallowed = get_not_allowed(allowed_classes, include_illegal=True)
     illegal = input_map[:, unallowed, :, :]
-    penalty = torch.sum(illegal)
+    penalty = F.l1_loss(illegal, torch.zeros_like(illegal, device=illegal.device))

     return penalty
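Review note: F.l1_loss against a zero tensor is the mean absolute value, so the penalty no longer scales with batch size and map area the way torch.sum did. A quick check of the scale difference on a dummy tensor (not repo code):

import torch
import torch.nn.functional as F

illegal = torch.rand(8, 5, 32, 32) * 0.01  # small probability mass on disallowed classes
print(torch.sum(illegal).item())           # ~205 here: grows with the element count
print(F.l1_loss(illegal, torch.zeros_like(illegal)).item())  # ~0.005: mean, independent of size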
@@ -254,7 +254,7 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
     return penalty

 class WGANGinkaLoss:
-    def __init__(self, lambda_gp=100, weight=[1, 0.4, 50, 0.2, 0.2, 0.05, 0.4]):
+    def __init__(self, lambda_gp=100, weight=[1, 0.4, 20, 0.2, 0.2, 0.05, 0.4]):
         # weight:
         # 1. Discriminator loss and tile-preservation loss (existing content in the modifiable region must not be changed)
         # 2. CE loss
@@ -335,16 +335,14 @@ class WGANGinkaLoss:
         # Stage 1 checks that an entrance exists
         entrance_loss = entrance_constraint_loss(probs_fake)
         losses.append(entrance_loss * self.weight[4])

-        # print(-js_divergence(fake_a, fake_b).item())
-
-        return sum(losses), minamo_loss, ce_loss, immutable_loss
+        return sum(losses), ce_loss

     def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
         probs_fake = F.softmax(fake, dim=1)

         fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
-        minamo_loss = -torch.mean(fake_scores)
+        minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
         illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
         constraint_loss = inner_constraint_loss(probs_fake)
         density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
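Review note: lambda_gp=100 in WGANGinkaLoss implies the critic side uses a gradient penalty, but discriminator_loss itself is outside this diff. For orientation, a standard WGAN-GP penalty in the style of Gulrajani et al. as a sketch (an assumption about the repo's approach, not its exact code):

import torch

def gradient_penalty(critic, stage, real, fake, tag_cond, val_cond):
    # Interpolate between real and fake samples.
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    inter = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(inter, stage, tag_cond, val_cond)
    grads = torch.autograd.grad(
        outputs=scores, inputs=inter,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True,
    )[0]
    grads = grads.view(real.size(0), -1)
    # Penalize deviation of the gradient norm from 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()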
@@ -392,14 +390,13 @@ class WGANGinkaLoss:

         return sum(losses)

-    def generator_input_head_loss(self, critic, map: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
-        probs = F.softmax(map, dim=1)
-        head_scores = critic(probs, 0, tag_cond, val_cond)
+    def generator_input_head_loss(self, critic, probs: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
+        head_scores = -torch.mean(critic(probs, 0, tag_cond, val_cond))
         probs_a, probs_b = probs.chunk(2, dim=0)

         losses = [
-            torch.mean(head_scores),
-            input_head_illegal_loss(probs),
+            head_scores,
+            input_head_illegal_loss(probs) * 50,
             -js_divergence(probs_a, probs_b, softmax=False) * 0.1
         ]
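Review note: the -js_divergence(probs_a, probs_b) term rewards the two halves of the batch for producing different maps, a light mode-collapse guard, while the *50 factor pushes the input head toward legality first. js_divergence is defined elsewhere in the repo; one definition consistent with the softmax=False call site, as a sketch:

import torch
import torch.nn.functional as F

def js_divergence(p, q, softmax=True, eps=1e-8):
    # Jensen-Shannon divergence between two class-probability maps [B, C, H, W].
    if softmax:
        p, q = F.softmax(p, dim=1), F.softmax(q, dim=1)
    m = 0.5 * (p + q)
    kl_pm = (p * ((p + eps).log() - (m + eps).log())).sum(dim=1).mean()
    kl_qm = (q * ((q + eps).log() - (m + eps).log())).sum(dim=1).mean()
    return 0.5 * (kl_pm + kl_qm)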
@@ -1,3 +1,4 @@
+import time
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -20,18 +21,17 @@ class GinkaModel(nn.Module):
         self.unet = GinkaUNet(64, base_ch, base_ch)
         self.output = GinkaOutput(base_ch, out_ch, (13, 13))

-    def forward(self, x, stage, tag_cond, val_cond, random=False):
+    def forward(self, x, stage, tag_cond, val_cond):
         B, D = tag_cond.shape
         stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
         cond = self.cond(tag_cond, val_cond, stage_tensor)
-        if random:
-            x_in = F.softmax(self.head(x, cond), dim=1)
+        if stage == 0:
+            x = self.head(x, cond)
         else:
-            x_in = x
-        x = self.input(x_in, cond)
-        x = self.unet(x, cond)
-        x = self.output(x, stage, cond)
-        return x, x_in
+            x = self.input(x, cond)
+            x = self.unet(x, cond)
+            x = self.output(x, stage, cond)
+        return x

 # Check VRAM usage
 if __name__ == "__main__":
@@ -45,12 +45,18 @@ if __name__ == "__main__":
     print_memory("after initialization")

     # Forward pass
-    output, _ = model(input, 1, tag, val, True)
+    start = time.perf_counter()
+    fake0 = model(input, 0, tag, val)
+    fake1 = model(F.softmax(fake0, dim=1), 1, tag, val)
+    fake2 = model(F.softmax(fake1, dim=1), 1, tag, val)
+    fake3 = model(F.softmax(fake2, dim=1), 1, tag, val)
+    end = time.perf_counter()

     print_memory("after forward pass")

+    print(f"Inference time: {end - start}")
     print(f"Input shape: feat={input.shape}")
-    print(f"Output shape: output={output.shape}")
+    print(f"Output shape: output={fake3.shape}")
     print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}")
     print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
     print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
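Review note: the benchmark now runs the full cascade, feeding each stage the softmax of the previous stage's logits — the same differentiable probability maps the critic sees during training — while argmax is reserved for collapsing logits to tile indices for rendering. A tiny sketch of the two readouts (shapes assumed from the 32-class, 32x32 setup in this diff):

import torch
import torch.nn.functional as F

logits = torch.randn(1, 32, 32, 32)  # [B, num_tile_classes, H, W]
probs = F.softmax(logits, dim=1)     # differentiable input for the next stage
tiles = torch.argmax(logits, dim=1)  # [B, H, W] integer tile map for rendering
print(probs.shape, tiles.shape)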
@@ -6,19 +6,23 @@ from ..common.cond import ConditionInjector
 class StageHead(nn.Module):
     def __init__(self, in_ch, out_ch, out_size=(13, 13)):
         super().__init__()
-        self.dec = ConvFusionModule(in_ch, in_ch*2, in_ch, 32, 32)
+        self.dec1 = ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32)
+        self.dec2 = ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32)
         self.pool = nn.Sequential(
-            ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32),
             ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32),
             ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32),

             nn.AdaptiveMaxPool2d(out_size),
             nn.Conv2d(in_ch, out_ch, 1)
         )
-        self.inject = ConditionInjector(256, in_ch)
+        self.inject1 = ConditionInjector(256, in_ch*2)
+        self.inject2 = ConditionInjector(256, in_ch*2)

     def forward(self, x, cond):
-        x = self.dec(x)
-        x = self.inject(x, cond)
+        x = self.dec1(x)
+        x = self.inject1(x, cond)
+        x = self.dec2(x)
+        x = self.inject2(x, cond)
         x = self.pool(x)
         return x
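Review note: the head deepens from one fused block to two at width in_ch*2, and pool drops its first ConvFusionModule because dec2 already outputs in_ch*2 channels. A shape walk-through using plain Conv2d stand-ins for ConvFusionModule (whose signature is assumed here to be (in, hidden, out, H, W)):

import torch
import torch.nn as nn

in_ch, out_ch = 64, 32
x = torch.randn(1, in_ch, 32, 32)
dec1 = nn.Conv2d(in_ch, in_ch * 2, 3, padding=1)      # stand-in for ConvFusionModule
dec2 = nn.Conv2d(in_ch * 2, in_ch * 2, 3, padding=1)
pool = nn.Sequential(
    nn.Conv2d(in_ch * 2, in_ch * 2, 3, padding=1),
    nn.Conv2d(in_ch * 2, in_ch, 3, padding=1),
    nn.AdaptiveMaxPool2d((13, 13)),
    nn.Conv2d(in_ch, out_ch, 1),
)
print(pool(dec2(dec1(x))).shape)  # torch.Size([1, 32, 13, 13])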
@@ -47,7 +47,7 @@ from shared.image import matrix_to_image_cv
 # 29. Stair entrance
 # 30. Arrow entrance

-BATCH_SIZE = 8
+BATCH_SIZE = 6

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
@@ -71,39 +71,46 @@ def parse_arguments():
     return args

 def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    fake1, _ = gen(masked1, 1, tag, val)
-    fake2, _ = gen(masked2, 2, tag, val)
-    fake3, _ = gen(masked3, 3, tag, val)
+    fake1 = gen(masked1, 1, tag, val)
+    fake2 = gen(masked2, 2, tag, val)
+    fake3 = gen(masked3, 3, tag, val)
     if detach:
         return fake1.detach(), fake2.detach(), fake3.detach()
     else:
         return fake1, fake2, fake3

 def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
+    if random:
+        fake0 = gen(input, 0, tag, val)
+        x_in = F.softmax(fake0, dim=1)
+    else:
+        fake0 = input
+        x_in = input
+
     if progress_detach:
-        fake1, x_in = gen(input.detach(), 1, tag, val, random)
-        fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
-        fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
+        fake1 = gen(x_in.detach(), 1, tag, val)
+        fake2 = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
+        fake3 = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
     else:
-        fake1, x_in = gen(input, 1, tag, val, random)
-        fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val)
-        fake3, _ = gen(F.softmax(fake2, dim=1), 3, tag, val)
+        fake1 = gen(x_in, 1, tag, val)
+        fake2 = gen(F.softmax(fake1, dim=1), 2, tag, val)
+        fake3 = gen(F.softmax(fake2, dim=1), 3, tag, val)
     if result_detach:
-        return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
+        return fake1.detach(), fake2.detach(), fake3.detach(), fake0.detach()
     else:
-        return fake1, fake2, fake3, x_in
+        return fake1, fake2, fake3, fake0

 def train():
     print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

     args = parse_arguments()

-    c_steps = 5
+    c_steps = 2
     g_steps = 1
     # Training stage
     train_stage = 1
     mask_ratio = 0.2  # Mask region size; increased by 0.1 each time, and training enters stage 2 once it reaches 0.9
     stage_epoch = 0  # Epoch count within the current stage, used to control the training process
     total_epoch = 0

     ginka = GinkaModel().to(device)
     minamo = MinamoModel2().to(device)
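Review note: progress_detach cuts the autograd graph between stages, so each stage's update backpropagates only through its own forward pass; result_detach additionally stops gradients for critic updates. A minimal illustration of what .detach() does to the chain (toy tensors, not repo code):

import torch

a = torch.ones(2, requires_grad=True)
b = a * 3
c = (b.detach() * 2 + b).sum()  # the detached branch contributes no gradient
c.backward()
print(a.grad)  # tensor([3., 3.]): only the non-detached branch flows back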
@@ -114,7 +121,7 @@ def train():
     dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)

     optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
-    optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
+    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))

     # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
     # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
@@ -134,9 +141,9 @@ def train():
         ginka.load_state_dict(data_ginka["model_state"], strict=False)
         minamo.load_state_dict(data_minamo["model_state"], strict=False)

-        if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
-            c_steps = data_ginka["c_steps"]
-            g_steps = data_ginka["g_steps"]
+        # if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
+        #     c_steps = data_ginka["c_steps"]
+        #     g_steps = data_ginka["g_steps"]

         if data_ginka.get("mask_ratio") is not None:
             mask_ratio = data_ginka["mask_ratio"]
@@ -147,6 +154,9 @@ def train():
         if data_ginka.get("stage") is not None:
             train_stage = data_ginka["stage"]

+        if data_ginka.get("total_epoch") is not None:
+            total_epoch = data_ginka["data_ginka"]
+
         if args.load_optim:
             if data_ginka.get("optim_state") is not None:
                 optimizer_ginka.load_state_dict(data_ginka["optim_state"])
@@ -156,6 +166,7 @@ def train():
         print("Train from loaded state.")

     curr_epoch = args.curr_epoch
+    first_curr = curr_epoch * 3

     if args.tuning:
         train_stage = 1
@@ -182,6 +193,8 @@ def train():
         loss_ce_total = torch.Tensor([0]).to(device)

         for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
+            rand = batch["rand"].to(device)
+            real0 = batch["real0"].to(device)
             real1 = batch["real1"].to(device)
             masked1 = batch["masked1"].to(device)
             real2 = batch["real2"].to(device)
@@ -200,23 +213,19 @@ def train():
             with torch.no_grad():
                 if train_stage == 1 or train_stage == 2:
                     fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
                 elif train_stage == 3 or train_stage == 4:
-                    fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
-
-                    if train_stage == 4:
-                        loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, masked2, x_in, tag_cond, val_cond)
+                    fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
+
+                if train_stage < 4:
+                    fake0 = ginka(rand, 0, tag_cond, val_cond)
+
+                loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, real0, fake0, tag_cond, val_cond)
                 loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
                 loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
                 loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)

-                dis = [dis1, dis2, dis3]
-                loss_d = [loss_d1, loss_d2, loss_d3]
-
-                if train_stage == 4:
-                    dis.append(dis0)
-                    loss_d.append(loss_d0)
+                dis = [dis0, dis1, dis2, dis3]
+                loss_d = [loss_d0, loss_d1, loss_d2, loss_d3]

                 dis_avg = sum(dis) / len(dis)
                 loss_d_avg = sum(loss_d) / len(loss_d)
@@ -237,33 +246,35 @@ def train():
             if train_stage == 1 or train_stage == 2:
                 fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False)

-                loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
-                loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
-                loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
+                loss_g1, loss_ce_g1 = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
+                loss_g2, loss_ce_g2 = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
+                loss_g3, loss_ce_g3 = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)

-                loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
+                loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0
                 loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)

                 loss_g.backward()
                 optimizer_ginka.step()
                 loss_total_ginka += loss_g.detach()
                 loss_ce_total += loss_ce.detach()

             elif train_stage == 3 or train_stage == 4:
-                fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
+                fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
+                if train_stage == 4:
+                    fake0 = F.softmax(fake0, dim=1)

-                loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, x_in, tag_cond, val_cond)
+                loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, fake0, tag_cond, val_cond)
                 loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
                 loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)

-                if train_stage == 4:
-                    loss_head = criterion.generator_input_head_loss(minamo, x_in, tag_cond, val_cond)
-                    loss_head.backward(retain_graph=True)
+                loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0

-                loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
-                loss_g.backward()
-                optimizer_ginka.step()
-                loss_total_ginka += loss_g.detach()
+                if train_stage < 4:
+                    fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1)
+
+                loss_g0 = criterion.generator_input_head_loss(minamo, fake0, tag_cond, val_cond)
+                loss_g += loss_g0
+
+                loss_g.backward()
+                optimizer_ginka.step()
+                loss_total_ginka += loss_g.detach()

         avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
         avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
@@ -311,8 +322,8 @@ def train():
                 fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)

             elif train_stage == 3 or train_stage == 4:
-                fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
-                x_in = torch.argmax(x_in, dim=1).cpu().numpy()
+                fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
+                fake0 = torch.argmax(fake0, dim=1).cpu().numpy()

             fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
             fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
@@ -339,7 +350,7 @@ def train():
             elif train_stage == 3 or train_stage == 4:
                 vline = np.full((416, gap, 3), color, dtype=np.uint8)  # Vertical divider
                 hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8)  # Horizontal divider
-                in_img = matrix_to_image_cv(x_in[i], tile_dict)
+                in_img = matrix_to_image_cv(fake0[i], tile_dict)
                 img = np.block([
                     [[in_img], [vline], [fake1_img]],
                     [[hline]],
@@ -352,46 +363,42 @@ def train():

         # Training flow control

-        if mask_ratio < 0.5 and avg_loss_ce < 0.5:
-            low_loss_epochs += 1
-        elif mask_ratio > 0.5 and avg_loss_ce < 0.5:
-            low_loss_epochs += 1
-        else:
-            low_loss_epochs = 0

         if train_stage >= 2:
-            if (epoch + 1) % 5 == 1:
+            if (epoch + 1) % 10 == 1:
                 train_stage = 3
-            elif (epoch + 1) % 5 == 3:
+            elif (epoch + 1) % 10 == 3:
                 train_stage = 4
-            elif (epoch + 1) % 5 == 0:
+            elif (epoch + 1) % 10 == 0:
                 train_stage = 2

-        if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
-            if mask_ratio >= 0.9:
-                train_stage = 2
-            mask_ratio += 0.2
-            mask_ratio = min(mask_ratio, 0.9)
-            low_loss_epochs = 0
-            stage_epoch = 0
+        if train_stage == 1:
+            if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
+                (mask_ratio > 0.3 and stage_epoch >= curr_epoch):
+                if mask_ratio >= 0.9:
+                    train_stage = 2
+                mask_ratio += 0.2
+                mask_ratio = min(mask_ratio, 0.9)
+                low_loss_epochs = 0
+                stage_epoch = 0

         stage_epoch += 1
         total_epoch += 1

         # scheduler_ginka.step()
         # scheduler_minamo.step()

-        if avg_dis < 0:
-            g_steps = max(int(-avg_dis * 5), 1)
-        else:
-            g_steps = 1
+        # if avg_dis < 0:
+        #     g_steps = max(int(-avg_dis * 5), 1)
+        # else:
+        #     g_steps = 1

         # if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
         #     g_steps += int(min(avg_loss_ginka * 5, 50))

-        if avg_loss_minamo > 0:
-            c_steps = int(min(5 + avg_loss_minamo * 5, 15))
-        else:
-            c_steps = 5
+        # if avg_loss_minamo > 0:
+        #     c_steps = int(min(5 + avg_loss_minamo * 5, 15))
+        # else:
+        #     c_steps = 5

         dataset.train_stage = train_stage
         dataset_val.train_stage = train_stage
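Review note: with the adaptive schedules commented out, c_steps=2 and g_steps=1 are now fixed — two critic updates per generator update (down from five, paired with raising the critic lr to 1e-4). The loop shape this implies, as a sketch with placeholder step functions (not the repo's exact loop):

c_steps, g_steps = 2, 1  # fixed ratio after this commit

def critic_step(batch):      # placeholder for the Minamo (critic) update
    pass

def generator_step(batch):   # placeholder for the Ginka (generator) update
    pass

for batch in range(10):      # stand-in for iterating the dataloader
    for _ in range(c_steps):
        critic_step(batch)
    for _ in range(g_steps):
        generator_step(batch)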