diff --git a/ginka/critic/model.py b/ginka/critic/model.py index 548d451..29620ee 100644 --- a/ginka/critic/model.py +++ b/ginka/critic/model.py @@ -239,20 +239,20 @@ class MinamoModel2(nn.Module): self.head2 = MinamoHead2(256, 256) self.head3 = MinamoHead2(256, 256) - self.inject1 = ConditionInjector(256, 128) - self.inject2 = ConditionInjector(256, 256) - self.inject3 = ConditionInjector(256, 256) + # self.inject1 = ConditionInjector(256, 128) + # self.inject2 = ConditionInjector(256, 256) + # self.inject3 = ConditionInjector(256, 256) def forward(self, x, stage, tag_cond, val_cond): B, D = tag_cond.shape stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device) cond = self.cond(tag_cond, val_cond, stage_tensor) x = self.conv1(x) - x = self.inject1(x, cond) + # x = self.inject1(x, cond) x = self.conv2(x) - x = self.inject2(x, cond) + # x = self.inject2(x, cond) x = self.conv3(x) - x = self.inject3(x, cond) + # x = self.inject3(x, cond) if stage == 0: score = self.head0(x, cond) diff --git a/ginka/dataset.py b/ginka/dataset.py index 886afda..c7f6833 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -1,4 +1,5 @@ import json +import math import random import torch import torch.nn.functional as F @@ -72,7 +73,33 @@ def apply_curriculum_mask( masked_maps[0, selected[:, 0], selected[:, 1]] = 1 # 置为“空地” return removed_maps, masked_maps + +def apply_curriculum_wall_mask( + maps: torch.Tensor, # [C, H, W] + mask_classes: List[int], # 要遮挡的类别索引 + remove_classes: List[int], # 要移除的类别索引 + mask_ratio: float # 遮挡比例 0~1 +) -> torch.Tensor: + C, H, W = maps.shape + masked_maps = maps.clone() + + # Step 1: 移除不需要的类别(全设为 0 类) + if remove_classes: + remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0 + masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0 + masked_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地” + removed_maps = masked_maps.clone() + + area = H * W * mask_ratio + l = math.ceil(math.sqrt(area)) + nx = random.randint(0, W - l) + ny = 
random.randint(0, H - l) + masked_maps[mask_classes, ny:ny+l, nx:nx+l] = 0 + masked_maps[0, ny:ny+l, nx:nx+l] = 1 + + return removed_maps, masked_maps + class GinkaWGANDataset(Dataset): def __init__(self, data_path: str, device): self.data = load_data(data_path) # 自定义数据加载函数 @@ -87,11 +114,14 @@ class GinkaWGANDataset(Dataset): def handle_stage1(self, target, tag_cond, val_cond): # 课程学习第一阶段,蒙版填充 - removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) + removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2) removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3) + rand = torch.rand(32, 32, 32, device=target.device) return { + "rand": rand, + "real0": removed1, "real1": removed1, "masked1": masked1, "real2": removed2, @@ -104,12 +134,15 @@ class GinkaWGANDataset(Dataset): def handle_stage2(self, target, tag_cond, val_cond): # 课程学习第二阶段,完全随机蒙版 - removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) + removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) # 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可 removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1)) removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1)) + rand = torch.rand(32, 32, 32, device=target.device) return { + "rand": rand, + "real0": removed1, "real1": removed1, "masked1": masked1, "real2": removed2, @@ -122,11 +155,14 @@ class GinkaWGANDataset(Dataset): def handle_stage3(self, target, tag_cond, val_cond): # 第三阶段,联合生成,输入随机蒙版 - removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) + removed1, masked1 = apply_curriculum_wall_mask(target, STAGE1_MASK, 
STAGE1_REMOVE, random.uniform(0.1, 0.9)) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) + rand = torch.rand(32, 32, 32, device=target.device) return { + "rand": rand, + "real0": removed1, "real1": removed1, "masked1": masked1, "real2": removed2, @@ -142,14 +178,15 @@ class GinkaWGANDataset(Dataset): removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) - _, masked = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, 0.5) rand = torch.rand(32, 32, 32, device=target.device) return { + "rand": rand, + "real0": removed1, "real1": removed1, "masked1": rand, "real2": removed2, - "masked2": masked, + "masked2": torch.zeros_like(target), "real3": removed3, "masked3": torch.zeros_like(target), "tag_cond": tag_cond, diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py index 44b59b5..2d772ad 100644 --- a/ginka/generator/loss.py +++ b/ginka/generator/loss.py @@ -155,7 +155,7 @@ def input_head_illegal_loss(input_map, allowed_classes=[0, 1, 2]): C = input_map.shape[1] unallowed = get_not_allowed(allowed_classes, include_illegal=True) illegal = input_map[:, unallowed, :, :] - penalty = torch.sum(illegal) + penalty = F.l1_loss(illegal, torch.zeros_like(illegal, device=illegal.device)) return penalty @@ -254,7 +254,7 @@ def illegal_penalty_loss(pred: torch.Tensor, legal_classes: list[int]): return penalty class WGANGinkaLoss: - def __init__(self, lambda_gp=100, weight=[1, 0.4, 50, 0.2, 0.2, 0.05, 0.4]): + def __init__(self, lambda_gp=100, weight=[1, 0.4, 20, 0.2, 0.2, 0.05, 0.4]): # weight: # 1. 判别器损失及图块维持损失(可修改部分的已有内容不可修改) # 2. 
CE 损失 @@ -335,16 +335,14 @@ class WGANGinkaLoss: # 第一个阶段检查入口存在性 entrance_loss = entrance_constraint_loss(probs_fake) losses.append(entrance_loss * self.weight[4]) - - # print(-js_divergence(fake_a, fake_b).item()) - - return sum(losses), minamo_loss, ce_loss, immutable_loss + + return sum(losses), ce_loss def generator_loss_total(self, critic, stage, fake, tag_cond, val_cond) -> torch.Tensor: probs_fake = F.softmax(fake, dim=1) fake_scores = critic(probs_fake, stage, tag_cond, val_cond) - minamo_loss = -torch.mean(fake_scores) + minamo_loss = -torch.mean(fake_scores) + modifiable_penalty_loss(probs_fake, input, STAGE_CHANGEABLE[stage]) illegal_loss = illegal_penalty_loss(probs_fake, STAGE_ALLOWED[stage]) constraint_loss = inner_constraint_loss(probs_fake) density_loss = compute_multi_density_loss(probs_fake, val_cond, DENSITY_STAGE[stage]) @@ -392,14 +390,13 @@ class WGANGinkaLoss: return sum(losses) - def generator_input_head_loss(self, critic, map: torch.Tensor, tag_cond, val_cond) -> torch.Tensor: - probs = F.softmax(map, dim=1) - head_scores = critic(probs, 0, tag_cond, val_cond) + def generator_input_head_loss(self, critic, probs: torch.Tensor, tag_cond, val_cond) -> torch.Tensor: + head_scores = -torch.mean(critic(probs, 0, tag_cond, val_cond)) probs_a, probs_b = probs.chunk(2, dim=0) losses = [ - torch.mean(head_scores), - input_head_illegal_loss(probs), + head_scores, + input_head_illegal_loss(probs) * 50, -js_divergence(probs_a, probs_b, softmax=False) * 0.1 ] diff --git a/ginka/generator/model.py b/ginka/generator/model.py index 30991a1..8244bd7 100644 --- a/ginka/generator/model.py +++ b/ginka/generator/model.py @@ -1,3 +1,4 @@ +import time import torch import torch.nn as nn import torch.nn.functional as F @@ -20,18 +21,17 @@ class GinkaModel(nn.Module): self.unet = GinkaUNet(64, base_ch, base_ch) self.output = GinkaOutput(base_ch, out_ch, (13, 13)) - def forward(self, x, stage, tag_cond, val_cond, random=False): + def forward(self, x, stage, tag_cond, 
val_cond): B, D = tag_cond.shape stage_tensor = torch.Tensor([stage]).expand(B, 1).to(x.device) cond = self.cond(tag_cond, val_cond, stage_tensor) - if random: - x_in = F.softmax(self.head(x, cond), dim=1) + if stage == 0: + x = self.head(x, cond) else: - x_in = x - x = self.input(x_in, cond) - x = self.unet(x, cond) - x = self.output(x, stage, cond) - return x, x_in + x = self.input(x, cond) + x = self.unet(x, cond) + x = self.output(x, stage, cond) + return x # 检查显存占用 if __name__ == "__main__": @@ -45,12 +45,18 @@ if __name__ == "__main__": print_memory("初始化后") # 前向传播 - output, _ = model(input, 1, tag, val, True) + start = time.perf_counter() + fake0 = model(input, 0, tag, val) + fake1 = model(F.softmax(fake0, dim=1), 1, tag, val) + fake2 = model(F.softmax(fake1, dim=1), 1, tag, val) + fake3 = model(F.softmax(fake2, dim=1), 1, tag, val) + end = time.perf_counter() print_memory("前向传播后") + print(f"推理耗时: {end - start}") print(f"输入形状: feat={input.shape}") - print(f"输出形状: output={output.shape}") + print(f"输出形状: output={fake3.shape}") print(f"Random parameters: {sum(p.numel() for p in model.head.parameters())}") print(f"Cond parameters: {sum(p.numel() for p in model.cond.parameters())}") print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}") diff --git a/ginka/generator/output.py b/ginka/generator/output.py index b63cac4..25e038c 100644 --- a/ginka/generator/output.py +++ b/ginka/generator/output.py @@ -6,19 +6,23 @@ from ..common.cond import ConditionInjector class StageHead(nn.Module): def __init__(self, in_ch, out_ch, out_size=(13, 13)): super().__init__() - self.dec = ConvFusionModule(in_ch, in_ch*2, in_ch, 32, 32) + self.dec1 = ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32) + self.dec2 = ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32) self.pool = nn.Sequential( - ConvFusionModule(in_ch, in_ch*2, in_ch*2, 32, 32), + ConvFusionModule(in_ch*2, in_ch*2, in_ch*2, 32, 32), ConvFusionModule(in_ch*2, in_ch*2, in_ch, 32, 32), 
nn.AdaptiveMaxPool2d(out_size), nn.Conv2d(in_ch, out_ch, 1) ) - self.inject = ConditionInjector(256, in_ch) + self.inject1 = ConditionInjector(256, in_ch*2) + self.inject2 = ConditionInjector(256, in_ch*2) def forward(self, x, cond): - x = self.dec(x) - x = self.inject(x, cond) + x = self.dec1(x) + x = self.inject1(x, cond) + x = self.dec2(x) + x = self.inject2(x, cond) x = self.pool(x) return x diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 5c700c5..d5930a2 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -47,7 +47,7 @@ from shared.image import matrix_to_image_cv # 29. 楼梯入口 # 30. 箭头入口 -BATCH_SIZE = 8 +BATCH_SIZE = 6 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) @@ -71,39 +71,46 @@ def parse_arguments(): return args def gen_curriculum(gen, masked1, masked2, masked3, tag, val, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - fake1, _ = gen(masked1, 1, tag, val) - fake2, _ = gen(masked2, 2, tag, val) - fake3, _ = gen(masked3, 3, tag, val) + fake1 = gen(masked1, 1, tag, val) + fake2 = gen(masked2, 2, tag, val) + fake3 = gen(masked3, 3, tag, val) if detach: return fake1.detach(), fake2.detach(), fake3.detach() else: return fake1, fake2, fake3 def gen_total(gen, input, tag, val, progress_detach=True, result_detach=False, random=False) -> torch.Tensor: + if random: + fake0 = gen(input, 0, tag, val) + x_in = F.softmax(fake0, dim=1) + else: + fake0 = input + x_in = input if progress_detach: - fake1, x_in = gen(input.detach(), 1, tag, val, random) - fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val) - fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val) + fake1 = gen(x_in.detach(), 1, tag, val) + fake2 = gen(F.softmax(fake1.detach(), dim=1), 2, tag, val) + fake3 = gen(F.softmax(fake2.detach(), dim=1), 3, tag, val) else: - fake1, x_in = gen(input, 1, tag, val, random) - fake2, _ = gen(F.softmax(fake1, dim=1), 2, tag, val) - fake3, _ = 
gen(F.softmax(fake2, dim=1), 3, tag, val) + fake1 = gen(x_in, 1, tag, val) + fake2 = gen(F.softmax(fake1, dim=1), 2, tag, val) + fake3 = gen(F.softmax(fake2, dim=1), 3, tag, val) if result_detach: - return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach() + return fake1.detach(), fake2.detach(), fake3.detach(), fake0.detach() else: - return fake1, fake2, fake3, x_in + return fake1, fake2, fake3, fake0 def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") args = parse_arguments() - c_steps = 5 + c_steps = 2 g_steps = 1 # 训练阶段 train_stage = 1 mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 + total_epoch = 0 ginka = GinkaModel().to(device) minamo = MinamoModel2().to(device) @@ -114,7 +121,7 @@ def train(): dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) - optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9)) + optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-4, betas=(0.0, 0.9)) # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs) @@ -134,9 +141,9 @@ def train(): ginka.load_state_dict(data_ginka["model_state"], strict=False) minamo.load_state_dict(data_minamo["model_state"], strict=False) - if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None: - c_steps = data_ginka["c_steps"] - g_steps = data_ginka["g_steps"] + # if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None: + # c_steps = data_ginka["c_steps"] + # g_steps = data_ginka["g_steps"] if data_ginka.get("mask_ratio") is not None: mask_ratio = data_ginka["mask_ratio"] @@ -147,6 +154,9 @@ def train(): if data_ginka.get("stage") is not None: train_stage = data_ginka["stage"] + if 
data_ginka.get("total_epoch") is not None: + total_epoch = data_ginka["total_epoch"] + if args.load_optim: if data_ginka.get("optim_state") is not None: optimizer_ginka.load_state_dict(data_ginka["optim_state"]) @@ -156,6 +166,7 @@ def train(): print("Train from loaded state.") curr_epoch = args.curr_epoch + first_curr = curr_epoch * 3 if args.tuning: train_stage = 1 @@ -182,6 +193,8 @@ def train(): loss_ce_total = torch.Tensor([0]).to(device) for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + rand = batch["rand"].to(device) + real0 = batch["real0"].to(device) real1 = batch["real1"].to(device) masked1 = batch["masked1"].to(device) real2 = batch["real2"].to(device) @@ -200,23 +213,19 @@ def train(): with torch.no_grad(): if train_stage == 1 or train_stage == 2: fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) - elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) - - if train_stage == 4: - loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, masked2, x_in, tag_cond, val_cond) - + fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) + + if train_stage < 4: + fake0 = ginka(rand, 0, tag_cond, val_cond) + + loss_d0, dis0 = criterion.discriminator_loss(minamo, 0, real0, fake0, tag_cond, val_cond) loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1, tag_cond, val_cond) loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2, tag_cond, val_cond) loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3, tag_cond, val_cond) - dis = [dis1, dis2, dis3] - loss_d = [loss_d1, loss_d2, loss_d3] - - if train_stage == 4: - dis.append(dis0) - loss_d.append(loss_d0) + dis = [dis0, dis1, dis2, dis3] + loss_d = [loss_d0, loss_d1, loss_d2, loss_d3] dis_avg = sum(dis) / len(dis) loss_d_avg = sum(loss_d) / 
len(loss_d) @@ -237,33 +246,35 @@ def train(): if train_stage == 1 or train_stage == 2: fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, False) - loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond) - loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond) - loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond) + loss_g1, loss_ce_g1 = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1, tag_cond, val_cond) + loss_g2, loss_ce_g2 = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2, tag_cond, val_cond) + loss_g3, loss_ce_g3 = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3, tag_cond, val_cond) - loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0 + loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0 loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3) - loss_g.backward() - optimizer_ginka.step() - loss_total_ginka += loss_g.detach() loss_ce_total += loss_ce.detach() elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4) + fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, False, train_stage == 4) + if train_stage == 4: + fake0 = F.softmax(fake0, dim=1) - loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, x_in, tag_cond, val_cond) + loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, fake0, tag_cond, val_cond) loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1, tag_cond, val_cond) loss_g3 = criterion.generator_loss_total_with_input(minamo, 3, fake3, fake2, tag_cond, val_cond) - if train_stage == 4: - loss_head = criterion.generator_input_head_loss(minamo, x_in, tag_cond, val_cond) - 
loss_head.backward(retain_graph=True) + loss_g = (loss_g1 * 3.0 + loss_g2 + loss_g3) / 5.0 + + if train_stage < 4: + fake0 = F.softmax(ginka(rand, 0, tag_cond, val_cond), dim=1) - loss_g = (loss_g1 * 3 + loss_g2 + loss_g3) / 5.0 - loss_g.backward() - optimizer_ginka.step() - loss_total_ginka += loss_g.detach() + loss_g0 = criterion.generator_input_head_loss(minamo, fake0, tag_cond, val_cond) + loss_g += loss_g0 + + loss_g.backward() + optimizer_ginka.step() + loss_total_ginka += loss_g.detach() avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps @@ -311,8 +322,8 @@ def train(): fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True) elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, x_in = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) - x_in = torch.argmax(x_in, dim=1).cpu().numpy() + fake1, fake2, fake3, fake0 = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4) + fake0 = torch.argmax(fake0, dim=1).cpu().numpy() fake1 = torch.argmax(fake1, dim=1).cpu().numpy() fake2 = torch.argmax(fake2, dim=1).cpu().numpy() @@ -339,7 +350,7 @@ def train(): elif train_stage == 3 or train_stage == 4: vline = np.full((416, gap, 3), color, dtype=np.uint8) # 垂直分割线 hline = np.full((gap, 2 * 416 + gap, 3), color, dtype=np.uint8) # 水平分割线 - in_img = matrix_to_image_cv(x_in[i], tile_dict) + in_img = matrix_to_image_cv(fake0[i], tile_dict) img = np.block([ [[in_img], [vline], [fake1_img]], [[hline]], @@ -352,46 +363,42 @@ def train(): # 训练流程控制 - if mask_ratio < 0.5 and avg_loss_ce < 0.5: - low_loss_epochs += 1 - elif mask_ratio > 0.5 and avg_loss_ce < 0.5: - low_loss_epochs += 1 - else: - low_loss_epochs = 0 - if train_stage >= 2: - if (epoch + 1) % 5 == 1: + if (epoch + 1) % 10 == 1: train_stage = 3 - elif (epoch + 1) % 5 == 3: + elif (epoch + 1) % 10 == 3: train_stage = 4 - elif (epoch + 1) % 
5 == 0: + elif (epoch + 1) % 10 == 0: train_stage = 2 - if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch: - if mask_ratio >= 0.9: - train_stage = 2 - mask_ratio += 0.2 - mask_ratio = min(mask_ratio, 0.9) - low_loss_epochs = 0 - stage_epoch = 0 + if train_stage == 1: + if (mask_ratio < 0.3 and stage_epoch >= first_curr) or \ + (mask_ratio > 0.3 and stage_epoch >= curr_epoch): + if mask_ratio >= 0.9: + train_stage = 2 + mask_ratio += 0.2 + mask_ratio = min(mask_ratio, 0.9) + low_loss_epochs = 0 + stage_epoch = 0 stage_epoch += 1 + total_epoch += 1 # scheduler_ginka.step() # scheduler_minamo.step() - if avg_dis < 0: - g_steps = max(int(-avg_dis * 5), 1) - else: - g_steps = 1 + # if avg_dis < 0: + # g_steps = max(int(-avg_dis * 5), 1) + # else: + # g_steps = 1 # if avg_loss_ginka > 0 and epoch > 20 and not args.resume: # g_steps += int(min(avg_loss_ginka * 5, 50)) - if avg_loss_minamo > 0: - c_steps = int(min(5 + avg_loss_minamo * 5, 15)) - else: - c_steps = 5 + # if avg_loss_minamo > 0: + # c_steps = int(min(5 + avg_loss_minamo * 5, 15)) + # else: + # c_steps = 5 dataset.train_stage = train_stage dataset_val.train_stage = train_stage