perf: 改进生成流程与损失计算

This commit is contained in:
unanmed 2025-05-15 17:44:29 +08:00
parent fa48863946
commit fb0323d874
6 changed files with 162 additions and 111 deletions

View File

@ -239,20 +239,20 @@ class MinamoModel2(nn.Module):
self.head2 = MinamoHead2(256, 256)
self.head3 = MinamoHead2(256, 256)
self.inject1 = ConditionInjector(256, 128)
self.inject2 = ConditionInjector(256, 256)
self.inject3 = ConditionInjector(256, 256)
# self.inject1 = ConditionInjector(256, 128)
# self.inject2 = ConditionInjector(256, 256)
# self.inject3 = ConditionInjector(256, 256)
def forward(self, x, stage, tag_cond, val_cond):
B, D = tag_cond.shape
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
cond = self.cond(tag_cond, val_cond, stage_tensor)
x = self.conv1(x)
x = self.inject1(x, cond)
# x = self.inject1(x, cond)
x = self.conv2(x)
x = self.inject2(x, cond)
# x = self.inject2(x, cond)
x = self.conv3(x)
x = self.inject3(x, cond)
# x = self.inject3(x, cond)
if stage == 0:
score = self.head0(x, cond)

View File

@ -1,4 +1,5 @@
import json
import math
import random
import torch
import torch.nn.functional as F
@ -72,7 +73,33 @@ def apply_curriculum_mask(
masked_maps[0, selected[:, 0], selected[:, 1]] = 1 # 置为“空地”
return removed_maps, masked_maps
def apply_curriculum_wall_mask(
    maps: torch.Tensor,       # [C, H, W] one-hot class maps
    mask_classes: List[int],  # class indices to blank out inside the window
    remove_classes: List[int],# class indices to remove from the whole map
    mask_ratio: float         # fraction (0~1) of the map area to mask
) -> tuple[torch.Tensor, torch.Tensor]:
    """Curriculum masking with a single square window.

    Step 1 removes ``remove_classes`` everywhere (cells they occupied become
    class 0, "empty ground").  Step 2 blanks ``mask_classes`` inside one
    randomly-placed square window whose area is roughly ``H*W*mask_ratio``
    (again marking the window as class 0).

    Returns:
        (removed_maps, masked_maps): the map after step 1 only, and the map
        after both steps.  Both are clones; ``maps`` is never modified.
    """
    C, H, W = maps.shape
    masked_maps = maps.clone()
    # Step 1: drop unwanted classes — any cell occupied by one of them is
    # zeroed across all channels and re-labelled as class 0 ("empty ground").
    if remove_classes:
        remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
        masked_maps[remove_mask.expand(C, -1, -1)] = 0
        masked_maps[0][remove_mask[0, :, :]] = 1
    removed_maps = masked_maps.clone()
    # Step 2: choose the window side from the requested area, clamped so the
    # window always fits inside the map.
    area = H * W * mask_ratio
    side = min(math.ceil(math.sqrt(area)), H, W)
    # BUGFIX: the row offset must be drawn against H and the column offset
    # against W (the original drew them against the opposite axes, which
    # truncated or mis-placed the window on non-square maps).
    top = random.randint(0, H - side)
    left = random.randint(0, W - side)
    masked_maps[mask_classes, top:top + side, left:left + side] = 0
    masked_maps[0, top:top + side, left:left + side] = 1  # mark as "empty"
    return removed_maps, masked_maps
class GinkaWGANDataset(Dataset):
def __init__(self, data_path: str, device):
self.data = load_data(data_path) # 自定义数据加载函数
@ -87,11 +114,14 @@ class GinkaWGANDataset(Dataset):
def handle_stage1(self, target, tag_cond, val_cond):
# 课程学习第一阶段,蒙版填充
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
rand = torch.rand(32, 32, 32, device=target.device)
return {
"rand": rand,
"real0": removed1,
"real1": removed1,
"masked1": masked1,
"real2": removed2,
@ -104,12 +134,15 @@ class GinkaWGANDataset(Dataset):
def handle_stage2(self, target, tag_cond, val_cond):
# 课程学习第二阶段,完全随机蒙版
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
# 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
rand = torch.rand(32, 32, 32, device=target.device)
return {
"rand": rand,
"real0": removed1,
"real1": removed1,
"masked1": masked1,
"real2": removed2,
@ -122,11 +155,14 @@ class GinkaWGANDataset(Dataset):
def handle_stage3(self, target, tag_cond, val_cond):
# 第三阶段,联合生成,输入随机蒙版
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
rand = torch.rand(32, 32, 32, device=target.device)
return {
"rand": rand,
"real0": removed1,
"real1": removed1,
"masked1": masked1,
"real2": removed2,
@ -142,14 +178,15 @@ class GinkaWGANDataset(Dataset):
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
_, masked = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, 0.5)
rand = torch.rand(32, 32, 32, device=target.device)
return {
"rand": rand,
"real0": removed1,
"real1": removed1,
"masked1": rand,
"real2": removed2,
"masked2": masked,
"masked2": torch.zeros_like(target),
"real3": removed3,
"masked3": torch.zeros_like(target),
"tag_cond": tag_cond,

View File

@ -155,7 +155,7 @@ def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]):
C = input_map.shape[1]
unallowed = get_not_allowed(allowed_classes, include_illegal=True)
illegal = input_map[:, unallowed, :, :]
penalty = torch.sum(illegal)
penalty = F.l1_loss(illegal, torch.zeros_like(illegal, device=illegal.device))
return penalty
@ -254,7 +254,7 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]):
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.4, 50, 0.2, 0.2, 0.05, 0.4]):
def __init__(self, lambda_gp=100, weight=[1, 0.4, 20, 0.2, 0.2, 0.05, 0.4]):
# weight:
# 1. 判别器损失及图块维持损失(可修改部分的已有内容不可修改)
# 2. CE 损失
@ -335,16 +335,14 @@ class WGANGinkaLoss:
# 第一个阶段检查入口存在性
entrance_loss = entrance_constraint_loss(probs_fake)
losses.append(entrance_loss * self.weight[4])
# print(-js_divergence(fake_a, fake_b).item())
return sum(losses), minamo_loss, ce_loss, immutable_loss
return sum(losses), ce_loss
def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor:
probs_fake = F.softmax(fake, dim=1)
fake_scores = critic(probs_fake, stage, tag_cond, val_cond)
minamo_loss = -torch.mean(fake_scores)
minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage])
illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage])
constraint_loss = inner_constraint_loss(probs_fake)
density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage])
@ -392,14 +390,13 @@ class WGANGinkaLoss:
return sum(losses)
def generator_input_head_loss(self, critic, map: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
probs = F.softmax(map, dim=1)
head_scores = critic(probs, 0, tag_cond, val_cond)
def generator_input_head_loss(self, critic, probs: torch.Tensor, tag_cond, val_cond) -> torch.Tensor:
head_scores = -torch.mean(critic(probs, 0, tag_cond, val_cond))
probs_a, probs_b = probs.chunk(2, dim=0)
losses = [
torch.mean(head_scores),
input_head_illegal_loss(probs),
head_scores,
input_head_illegal_loss(probs) * 50,
-js_divergence(probs_a, probs_b, softmax=False) * 0.1
]

View File

@ -1,3 +1,4 @@
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
@ -20,18 +21,17 @@ class GinkaModel(nn.Module):
self.unet = GinkaUNet(64, base_ch, base_ch)
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x, stage, tag_cond, val_cond, random=False):
def forward(self, x, stage, tag_cond, val_cond):
B, D = tag_cond.shape
stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device)
cond = self.cond(tag_cond, val_cond, stage_tensor)
if random:
x_in = F.softmax(self.head(x, cond), dim=1)
if stage == 0:
x = self.head(x, cond)
else:
x_in = x
x = self.input(x_in, cond)
x = self.unet(x, cond)
x = self.output(x, stage, cond)
return x, x_in
x = self.input(x, cond)
x = self.unet(x, cond)
x = self.output(x, stage, cond)
return x
# 检查显存占用
if __name__ == "__main__":
@ -45,12 +45,18 @@ if __name__ == "__main__":
print_memory("初始化后")
# 前向传播
output, _ = model(input, 1, tag, val, True)
start = time.perf_counter()
fake0 = model(input, 0, tag, val)
fake1 = model(F.softmax(fake0, dim=1), 1, tag, val)
fake2 = model(F.softmax(fake1, dim=1), 1, tag, val)
fake3 = model(F.softmax(fake2, dim=1), 1, tag, val)
end = time.perf_counter()
print_memory("前向传播后")
print(f"推理耗时: {end - start}")
print(f"输入形状: feat={input.shape}")
print(f"输出形状: output={output.shape}")
print(f"输出形状: output={fake3.shape}")
print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}")
print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}")
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")

View File

@ -6,19 +6,23 @@ from ..common.cond import ConditionInjector
class StageHead(nn.Module):
def __init__(self, in_ch, out_ch, out_size=(13, 13)):
super().__init__()
self.dec = ConvFusionModule(in_ch, in_ch*2, in_ch, 32, 32)
self.dec1 = ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32)
self.dec2 = ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32)
self.pool = nn.Sequential(
ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32),
ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32),
ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32),
nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(in_ch, out_ch, 1)
)
self.inject = ConditionInjector(256, in_ch)
self.inject1 = ConditionInjector(256, in_ch*2)
self.inject2 = ConditionInjector(256, in_ch*2)
def forward(self, x, cond):
x = self.dec(x)
x = self.inject(x, cond)
x = self.dec1(x)
x = self.inject1(x, cond)
x = self.dec2(x)
x = self.inject2(x, cond)
x = self.pool(x)
return x

View File

@ -47,7 +47,7 @@ from shared.image import matrix_to_image_cv
# 29. 楼梯入口
# 30. 箭头入口
BATCH_SIZE = 8
BATCH_SIZE = 6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@ -71,39 +71,46 @@ def parse_arguments():
return args
def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
fake1, _ = gen(masked1, 1, tag, val)
fake2, _ = gen(masked2, 2, tag, val)
fake3, _ = gen(masked3, 3, tag, val)
fake1 = gen(masked1, 1, tag, val)
fake2 = gen(masked2, 2, tag, val)
fake3 = gen(masked3, 3, tag, val)
if detach:
return fake1.detach(), fake2.detach(), fake3.detach()
else:
return fake1, fake2, fake3
def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
if random:
fake0 = gen(input, 0, tag, val)
x_in = F.softmax(fake0, dim=1)
else:
fake0 = input
x_in = input
if progress_detach:
fake1, x_in = gen(input.detach(), 1, tag, val, random)
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
fake1 = gen(x_in.detach(), 1, tag, val)
fake2 = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val)
fake3 = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val)
else:
fake1, x_in = gen(input, 1, tag, val, random)
fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val)
fake3, _ = gen(F.softmax(fake2, dim=1), 3, tag, val)
fake1 = gen(x_in, 1, tag, val)
fake2 = gen(F.softmax(fake1, dim=1), 2, tag, val)
fake3 = gen(F.softmax(fake2, dim=1), 3, tag, val)
if result_detach:
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
return fake1.detach(), fake2.detach(), fake3.detach(), fake0.detach()
else:
return fake1, fake2, fake3, x_in
return fake1, fake2, fake3, fake0
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments()
c_steps = 5
c_steps = 2
g_steps = 1
# 训练阶段
train_stage = 1
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
total_epoch = 0
ginka = GinkaModel().to(device)
minamo = MinamoModel2().to(device)
@ -114,7 +121,7 @@ def train():
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9))
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
@ -134,9 +141,9 @@ def train():
ginka.load_state_dict(data_ginka["model_state"], strict=False)
minamo.load_state_dict(data_minamo["model_state"], strict=False)
if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
c_steps = data_ginka["c_steps"]
g_steps = data_ginka["g_steps"]
# if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
# c_steps = data_ginka["c_steps"]
# g_steps = data_ginka["g_steps"]
if data_ginka.get("mask_ratio") is not None:
mask_ratio = data_ginka["mask_ratio"]
@ -147,6 +154,9 @@ def train():
if data_ginka.get("stage") is not None:
train_stage = data_ginka["stage"]
if data_ginka.get("total_epoch") is not None:
total_epoch = data_ginka["data_ginka"]
if args.load_optim:
if data_ginka.get("optim_state") is not None:
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
@ -156,6 +166,7 @@ def train():
print("Train from loaded state.")
curr_epoch = args.curr_epoch
first_curr = curr_epoch * 3
if args.tuning:
train_stage = 1
@ -182,6 +193,8 @@ def train():
loss_ce_total = torch.Tensor([0]).to(device)
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
rand = batch["rand"].to(device)
real0 = batch["real0"].to(device)
real1 = batch["real1"].to(device)
masked1 = batch["masked1"].to(device)
real2 = batch["real2"].to(device)
@ -200,23 +213,19 @@ def train():
with torch.no_grad():
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
if train_stage == 4:
loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, masked2, x_in, tag_cond, val_cond)
fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
if train_stage < 4:
fake0 = ginka(rand, 0, tag_cond, val_cond)
loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, real0, fake0, tag_cond, val_cond)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond)
dis = [dis1, dis2, dis3]
loss_d = [loss_d1, loss_d2, loss_d3]
if train_stage == 4:
dis.append(dis0)
loss_d.append(loss_d0)
dis = [dis0, dis1, dis2, dis3]
loss_d = [loss_d0, loss_d1, loss_d2, loss_d3]
dis_avg = sum(dis) / len(dis)
loss_d_avg = sum(loss_d) / len(loss_d)
@ -237,33 +246,35 @@ def train():
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False)
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
loss_g1, loss_ce_g1 = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond)
loss_g2, loss_ce_g2 = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond)
loss_g3, loss_ce_g3 = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond)
loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
loss_ce_total += loss_ce.detach()
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4)
if train_stage == 4:
fake0 = F.softmax(fake0, dim=1)
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, x_in, tag_cond, val_cond)
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, fake0, tag_cond, val_cond)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond)
loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond)
if train_stage == 4:
loss_head = criterion.generator_input_head_loss(minamo, x_in, tag_cond, val_cond)
loss_head.backward(retain_graph=True)
loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0
if train_stage < 4:
fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1)
loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
loss_g0 = criterion.generator_input_head_loss(minamo, fake0, tag_cond, val_cond)
loss_g += loss_g0
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
@ -311,8 +322,8 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
x_in = torch.argmax(x_in, dim=1).cpu().numpy()
fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
fake0 = torch.argmax(fake0, dim=1).cpu().numpy()
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
@ -339,7 +350,7 @@ def train():
elif train_stage == 3 or train_stage == 4:
vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线
hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8) # 水平分割线
in_img = matrix_to_image_cv(x_in[i], tile_dict)
in_img = matrix_to_image_cv(fake0[i], tile_dict)
img = np.block([
[[in_img], [vline], [fake1_img]],
[[hline]],
@ -352,46 +363,42 @@ def train():
# 训练流程控制
if mask_ratio < 0.5 and avg_loss_ce < 0.5:
low_loss_epochs += 1
elif mask_ratio > 0.5 and avg_loss_ce < 0.5:
low_loss_epochs += 1
else:
low_loss_epochs = 0
if train_stage >= 2:
if (epoch + 1) % 5 == 1:
if (epoch + 1) % 10 == 1:
train_stage = 3
elif (epoch + 1) % 5 == 3:
elif (epoch + 1) % 10 == 3:
train_stage = 4
elif (epoch + 1) % 5 == 0:
elif (epoch + 1) % 10 == 0:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
stage_epoch = 0
if train_stage == 1:
if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \
(mask_ratio > 0.3 and stage_epoch >= curr_epoch):
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
stage_epoch = 0
stage_epoch += 1
total_epoch += 1
# scheduler_ginka.step()
# scheduler_minamo.step()
if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1)
else:
g_steps = 1
# if avg_dis < 0:
# g_steps = max(int(-avg_dis * 5), 1)
# else:
# g_steps = 1
# if avg_loss_ginka > 0 and epoch > 20 and not args.resume:
# g_steps += int(min(avg_loss_ginka * 5, 50))
if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else:
c_steps = 5
# if avg_loss_minamo > 0:
# c_steps = int(min(5 + avg_loss_minamo * 5, 15))
# else:
# c_steps = 5
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage