diff --git a/ginka/dataset.py b/ginka/dataset.py index 21c0906..10b7d6e 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -83,12 +83,12 @@ class GinkaWGANDataset(Dataset): self.mask_ratio1 = 0.1 self.mask_ratio2 = 0.1 self.mask_ratio3 = 0.1 - self.random_ratio = 0.0 def __len__(self): return len(self.data) def handle_stage1(self, target): + # 课程学习第一阶段,蒙版填充 removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1) removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2) removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3) @@ -96,33 +96,30 @@ class GinkaWGANDataset(Dataset): return removed1, masked1, removed2, masked2, removed3, masked3 def handle_stage2(self, target): + # 课程学习第二阶段,完全随机蒙版 removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) # 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可 removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1)) removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1)) - - if self.random_ratio > 0: - rd = random.uniform(0, self.random_ratio) - masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd) - masked2 = random_smooth_onehot(masked2, min_main=1 - rd, max_main=1.0, epsilon=rd) - masked3 = random_smooth_onehot(masked3, min_main=1 - rd, max_main=1.0, epsilon=rd) return removed1, masked1, removed2, masked2, removed3, masked3 def handle_stage3(self, target): + # 第三阶段,联合生成,输入随机蒙版 rd = random.uniform(0, self.random_ratio) removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9)) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd) return removed1, masked1, removed2, 
torch.zeros_like(target), removed3, torch.zeros_like(target) - + def handle_stage4(self, target): - input1 = torch.rand((32, 13, 13)) + # 第四阶段,与第二阶段交替进行,完全随机输入 removed1 = apply_curriculum_remove(target, STAGE1_REMOVE) removed2 = apply_curriculum_remove(target, STAGE2_REMOVE) removed3 = apply_curriculum_remove(target, STAGE3_REMOVE) - return removed1, input1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target) + rand = torch.rand(32, 32, 32, device=target.device) + return removed1, rand, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target) def __getitem__(self, idx): item = self.data[idx] @@ -137,19 +134,9 @@ class GinkaWGANDataset(Dataset): elif self.train_stage == 3: return self.handle_stage3(target) - + elif self.train_stage == 4: - self.mask_ratio1 = self.mask_ratio2 = self.mask_ratio3 = random.uniform(0, 0.9) - self.random_ratio = 0.2 - mode = random.choices([1, 2, 3, 4], weights=[0.2, 0.2, 0.2, 0.4]) - if mode == 1: - return self.handle_stage1(target) - elif mode == 2: - return self.handle_stage2(target) - elif mode == 3: - return self.handle_stage3(target) - else: - return self.handle_stage4(target) + return self.handle_stage4(target) raise RuntimeError(f"Invalid train stage: {self.train_stage}") \ No newline at end of file diff --git a/ginka/model/common.py b/ginka/model/common.py new file mode 100644 index 0000000..0121583 --- /dev/null +++ b/ginka/model/common.py @@ -0,0 +1,64 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import GCNConv +from torch_geometric.utils import grid + +class DoubleConvBlock(nn.Module): + def __init__(self, feats: tuple[int, int, int]): + super().__init__() + self.cnn = nn.Sequential( + nn.Conv2d(feats[0], feats[1], 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(feats[1]), + nn.ELU(), + + nn.Conv2d(feats[1], feats[2], 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(feats[2]), + nn.ELU(), + ) + + def forward(self, x): 
+ x = self.cnn(x) + return x + +class GCNBlock(nn.Module): + def __init__(self, in_ch, hidden_ch, out_ch, w, h): + super().__init__() + self.conv1 = GCNConv(in_ch, hidden_ch) + self.conv2 = GCNConv(hidden_ch, out_ch) + self.norm1 = nn.LayerNorm(hidden_ch) + self.norm2 = nn.LayerNorm(out_ch) + self.single_edge_index, _ = grid(h, w) # [2, E] for a single map + + def forward(self, x): + # x: [B, C, H, W] + B, C, H, W = x.shape + + # Reshape to [B * H * W, C] + x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) + + # Construct batched edge index + device = x.device + edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W) + + # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling) + # batch = torch.arange(B, device=device).repeat_interleave(H * W) + + # GCN forward + x = self.conv1(x, edge_index) + x = F.elu(self.norm1(x)) + x = self.conv2(x, edge_index) + x = F.elu(self.norm2(x)) + + # Reshape back to [B, C, H, W] + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + return x + + def _batch_edge_index(self, B, edge_index, num_nodes_per_batch): + # 批次偏移 edge_index + edge_index = edge_index.clone() # [2, E] + batch_edge_index = [] + for i in range(B): + offset = i * num_nodes_per_batch + batch_edge_index.append(edge_index + offset) + return torch.cat(batch_edge_index, dim=1) \ No newline at end of file diff --git a/ginka/model/input.py b/ginka/model/input.py index 1d13496..554aec3 100644 --- a/ginka/model/input.py +++ b/ginka/model/input.py @@ -1,6 +1,31 @@ import torch import torch.nn as nn +class RandomInputHead(nn.Module): + def __init__(self, in_size=(32, 32), out_size=(32, 32)): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(32, 32, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(32), + nn.ELU(), + + nn.Conv2d(32, 64, 3, padding=1, padding_mode='replicate'), + nn.InstanceNorm2d(64), + nn.ELU(), + + nn.Conv2d(64, 128, 3, padding=1, padding_mode='replicate'), + 
nn.InstanceNorm2d(128), + nn.ELU(), + ) + self.out_conv = nn.Sequential( + nn.Conv2d(128, 32, 1), + ) + + def forward(self, x): + x = self.conv(x) + x = self.out_conv(x) + return x + class GinkaInput(nn.Module): def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)): super().__init__() diff --git a/ginka/model/loss.py b/ginka/model/loss.py index b2a0431..b5ad2dd 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -64,7 +64,7 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]): return loss_unallowed -def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]): +def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13]): """限定内部允许出现的图块种类 Args: @@ -235,6 +235,21 @@ def adaptive_count_loss( return total_loss +def input_head_illegal_loss(input_map, allowed_classes=(0, 1)): + C = input_map.shape[1] + mask = torch.ones(C, device=input_map.device) + mask[list(allowed_classes)] = 0 # 屏蔽允许的类别,其余为 1 + illegal_class_penalty = (input_map * mask.view(1, -1, 1, 1)).sum() / input_map.numel() + + return illegal_class_penalty + +def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1): + wall_prob = input_map[:, wall_class] # [B, H, W] + wall_ratio = wall_prob.mean() # 计算平均墙体占比 + wall_penalty = torch.clamp(wall_ratio - max_wall_ratio, min=0.0) # 超过则惩罚 + + return wall_penalty + class GinkaLoss(nn.Module): def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]): """Ginka Model 损失函数部分 @@ -310,7 +325,7 @@ def js_divergence(p, q, eps=1e-6, softmax=False): kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m) kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m) - return 0.5 * (kl_pm + kl_qm) + return torch.clamp(0.5 * (kl_pm + kl_qm), max=10) def immutable_penalty_loss( pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int] @@ -322,7 +337,6 @@ def 
immutable_penalty_loss( input: 模型输出 [B, C, H, W],概率分布 (softmax 后) target: 原始输入图 [B, C, H, W],概率分布 (softmax 后) modifiable_classes: 允许被修改的类别列表 - penalty_weight: 对非允许修改区域的惩罚系数 """ not_allowed = get_not_allowed(modifiable_classes, include_illegal=True) input_mask = pred[:, not_allowed, :, :] @@ -330,14 +344,17 @@ def immutable_penalty_loss( target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1) target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float() - # 差异区域(模型试图改变的地方) - penalty = F.l1_loss(input_mask, target_mask) + target_mask = torch.log(target_mask + 1e-6) # 转换为 log 概率分布 + input_mask = torch.log(input_mask + 1e-6) # 转换为 log 概率分布 - return penalty + # 差异区域(模型试图改变的地方) + penalty = F.kl_div(input_mask, target_mask, reduction='batchmean', log_target=True) + + return torch.clamp(penalty, max=1) class WGANGinkaLoss: - def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.01]): - # weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失 + def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]): + # weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失 self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight @@ -402,18 +419,18 @@ class WGANGinkaLoss: fake_scores, _, _ = critic(probs_fake, fake_graph, stage) minamo_loss = -torch.mean(fake_scores) - ce_loss = F.cross_entropy(fake, real) + ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage]) constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake) - fake_a, fake_b = fake.chunk(2, dim=0) + # fake_a, fake_b = fake.chunk(2, dim=0) losses = [ minamo_loss * self.weight[0], - ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小 + ce_loss * self.weight[1], # 蒙版越大,交叉熵损失权重越小 immutable_loss * self.weight[2], constraint_loss * self.weight[3], - -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], + # 
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], ] if stage == 1: @@ -433,12 +450,12 @@ class WGANGinkaLoss: minamo_loss = -torch.mean(fake_scores) constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake) - fake_a, fake_b = fake.chunk(2, dim=0) + # fake_a, fake_b = fake.chunk(2, dim=0) losses = [ minamo_loss * self.weight[0], constraint_loss * self.weight[3], - -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], + # -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], ] if stage == 1: @@ -457,13 +474,13 @@ class WGANGinkaLoss: immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage]) constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake) - fake_a, fake_b = fake.chunk(2, dim=0) + # fake_a, fake_b = fake.chunk(2, dim=0) losses = [ minamo_loss * self.weight[0], immutable_loss * self.weight[2], constraint_loss * self.weight[3], - -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], + # -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5], ] if stage == 1: @@ -472,3 +489,14 @@ class WGANGinkaLoss: losses.append(entrance_loss * self.weight[4]) return sum(losses) + + def generator_input_head_loss(self, probs: torch.Tensor) -> torch.Tensor: + probs_a, probs_b = probs.chunk(2, dim=0) + + losses = [ + input_head_illegal_loss(probs), + input_head_wall_loss(probs), + -js_divergence(probs_a, probs_b, softmax=False) * 0.2 + ] + + return sum(losses) diff --git a/ginka/model/output.py b/ginka/model/output.py index d93931c..f59f8af 100644 --- a/ginka/model/output.py +++ b/ginka/model/output.py @@ -1,25 +1,23 @@ import torch import torch.nn as nn +from .common import GCNBlock, DoubleConvBlock class StageHead(nn.Module): def __init__(self, in_ch, out_ch, out_size=(13, 13)): super().__init__() - self.head = nn.Sequential( - nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'), - nn.InstanceNorm2d(in_ch), 
- nn.ELU(), - - nn.Conv2d(in_ch, in_ch, 1), - nn.InstanceNorm2d(in_ch), - nn.ELU(), - ) + self.cnn_head = DoubleConvBlock([in_ch, in_ch*2, in_ch]) + self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32) + self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch]) self.pool = nn.Sequential( nn.AdaptiveMaxPool2d(out_size), nn.Conv2d(in_ch, out_ch, 1) ) def forward(self, x): - x = self.head(x) + x_cnn = self.cnn_head(x) + x_gcn = self.gcn_head(x) + x = torch.cat([x_cnn, x_gcn], dim=1) + x = self.fusion(x) x = self.pool(x) return x diff --git a/ginka/model/unet.py b/ginka/model/unet.py index d72211b..7d5dfa9 100644 --- a/ginka/model/unet.py +++ b/ginka/model/unet.py @@ -1,9 +1,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -from torch_geometric.nn import GCNConv -from torch_geometric.utils import grid from shared.attention import ChannelAttention +from .common import GCNBlock, DoubleConvBlock class GinkaTransformerEncoder(nn.Module): def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6): @@ -35,7 +34,7 @@ class GinkaTransformerEncoder(nn.Module): return x class ConvBlock(nn.Module): - def __init__(self, in_ch, out_ch, atte=True): + def __init__(self, in_ch, out_ch, attn=True): super().__init__() self.conv = nn.Sequential( nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'), @@ -44,63 +43,17 @@ class ConvBlock(nn.Module): nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'), nn.InstanceNorm2d(out_ch), ) - if atte: + if attn: self.conv.append(ChannelAttention(out_ch)) self.conv.append(nn.ELU()) def forward(self, x): return self.conv(x) -class GCNBlock(nn.Module): - def __init__(self, in_ch, hidden_ch, out_ch, w, h): - super().__init__() - self.conv1 = GCNConv(in_ch, hidden_ch) - self.conv2 = GCNConv(hidden_ch, out_ch) - self.norm1 = nn.LayerNorm(hidden_ch) - self.norm2 = nn.LayerNorm(out_ch) - self.single_edge_index, _ = grid(h, w) # [2, E] for a single map - - def forward(self, 
x): - # x: [B, C, H, W] - B, C, H, W = x.shape - - # Reshape to [B * H * W, C] - x = x.permute(0, 2, 3, 1).reshape(B * H * W, C) - - # Construct batched edge index - device = x.device - edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W) - - # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling) - # batch = torch.arange(B, device=device).repeat_interleave(H * W) - - # GCN forward - x = self.conv1(x, edge_index) - x = F.elu(self.norm1(x)) - x = self.conv2(x, edge_index) - x = F.elu(self.norm2(x)) - - # Reshape back to [B, C, H, W] - x = x.view(B, H, W, -1).permute(0, 3, 1, 2) - return x - - def _batch_edge_index(self, B, edge_index, num_nodes_per_batch): - # 批次偏移 edge_index - edge_index = edge_index.clone() # [2, E] - batch_edge_index = [] - for i in range(B): - offset = i * num_nodes_per_batch - batch_edge_index.append(edge_index + offset) - return torch.cat(batch_edge_index, dim=1) - class FusionModule(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() - self.conv = nn.Sequential( - nn.Conv2d(in_ch, out_ch, 1), - nn.InstanceNorm2d(out_ch), - nn.ELU() - ) + self.conv = DoubleConvBlock([in_ch, out_ch, out_ch]) def forward(self, x1, x2): x = torch.cat([x1, x2], dim=1) diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 2392f34..edf603c 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -4,17 +4,16 @@ import sys from datetime import datetime import torch import torch.optim as optim +import torch.nn.functional as F import cv2 from torch_geometric.loader import DataLoader from tqdm import tqdm from .model.model import GinkaModel from .dataset import GinkaWGANDataset from .model.loss import WGANGinkaLoss +from .model.input import RandomInputHead from minamo.model.model import MinamoScoreModule -from minamo.model.similarity import MinamoSimilarityModel -from shared.graph import batch_convert_soft_map_to_graph from shared.image import matrix_to_image_cv -from 
shared.constant import VISION_WEIGHT, TOPO_WEIGHT BATCH_SIZE = 16 @@ -67,17 +66,15 @@ def train(): c_steps = 5 g_steps = 1 - # 1 代表课程学习阶段,2 代表课程学习后,逐渐转为联合学习的阶段 - # 3 代表课程学习后的联合遮挡学习阶段,4 代表最后随机输入的联合学习阶段 + # 训练阶段 train_stage = 1 + last_stage = False mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 - random_ratio = 0 - stage3_epoch = 0 # 第三阶段 epoch 数,若干轮后进入第四阶段 + stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 - ginka = GinkaModel() - minamo = MinamoScoreModule() - ginka.to(device) - minamo.to(device) + ginka = GinkaModel().to(device) + ginka_head = RandomInputHead().to(device) + minamo = MinamoScoreModule().to(device) dataset = GinkaWGANDataset(args.train, device) dataset_val = GinkaWGANDataset(args.validate, device) @@ -85,6 +82,7 @@ def train(): dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) + optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) @@ -112,15 +110,15 @@ def train(): if data_ginka.get("mask_ratio") is not None: mask_ratio = data_ginka["mask_ratio"] - if data_ginka.get("random_ratio") is not None: - random_ratio = data_ginka["random_ratio"] - - if data_ginka.get("stage_epoch3") is not None: - stage3_epoch = data_ginka["stage_epoch3"] + if data_ginka.get("stage_epoch") is not None: + stage_epoch = data_ginka["stage_epoch"] if data_ginka.get("stage") is not None: train_stage = data_ginka["stage"] + if data_ginka.get("last_stage") is not None: + last_stage = data_ginka["last_stage"] + if args.load_optim: if data_ginka.get("optim_state") is not None: optimizer_ginka.load_state_dict(data_ginka["optim_state"]) @@ -131,13 +129,11 @@ def train(): dataset.mask_ratio1 = mask_ratio dataset.mask_ratio2 = mask_ratio dataset.mask_ratio3 = mask_ratio - 
dataset.random_ratio = random_ratio dataset_val.train_stage = train_stage dataset_val.mask_ratio1 = mask_ratio dataset_val.mask_ratio2 = mask_ratio dataset_val.mask_ratio3 = mask_ratio - dataset_val.random_ratio = random_ratio print("Train from loaded state.") @@ -152,16 +148,34 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch] + if train_stage == 4: + # 最后一个阶段训练输入头 + count = 5 if stage_epoch <= 20 else 2 + for _ in range(count): + optimizer_head.zero_grad() + output = F.softmax(ginka_head(masked1), dim=1) + loss_head = criterion.generator_input_head_loss(output) + loss_head.backward() + optimizer_head.step() + # ---------- 训练判别器 for _ in range(c_steps): # 生成假样本 optimizer_minamo.zero_grad() optimizer_ginka.zero_grad() - if train_stage == 1 or train_stage == 2: - fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) + optimizer_head.zero_grad() + + with torch.no_grad(): + if train_stage == 1 or train_stage == 2: + fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) + + elif train_stage == 3: + fake1, fake2, fake3 = gen_total(ginka, masked1, True, True) - elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3 = gen_total(ginka, masked1, True, True) + elif train_stage == 4: + input = F.softmax(ginka_head(masked1), dim=1) + fake1, fake2, fake3 = gen_total(ginka, input, True, True) + loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1) loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2) @@ -183,6 +197,7 @@ def train(): for _ in range(g_steps): optimizer_minamo.zero_grad() optimizer_ginka.zero_grad() + optimizer_head.zero_grad() if train_stage == 1 or train_stage == 2: fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False) @@ -199,10 +214,12 @@ def train(): loss_ce_total += loss_ce.detach() elif train_stage == 3 or 
train_stage == 4: - fake1, fake2, fake3 = gen_total(ginka, masked1, True, False) + input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1) + + fake1, fake2, fake3 = gen_total(ginka, input, True, False) if train_stage == 3: - loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1) + loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input) else: loss_g1 = criterion.generator_loss_total(minamo, 1, fake1) loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1) @@ -221,43 +238,42 @@ def train(): f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " + f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " + f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " + - f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}" + f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}" ) - if avg_loss_ce < 0.5: + if avg_loss_ce < 1.0: low_loss_epochs += 1 else: low_loss_epochs = 0 - - if low_loss_epochs >= 3 and train_stage == 2: - if random_ratio >= 0.5: - train_stage = 3 - random_ratio += 0.2 - random_ratio = min(random_ratio, 0.5) - low_loss_epochs = 0 if low_loss_epochs >= 3 and train_stage == 1: if mask_ratio >= 0.9: train_stage = 2 + stage_epoch = 0 mask_ratio += 0.2 mask_ratio = min(mask_ratio, 0.9) low_loss_epochs = 0 - if train_stage == 3: - stage3_epoch += 1 - # 十轮足够了 - if stage3_epoch >= 10: - train_stage = 4 - stage3_epoch = 0 + if train_stage == 3 or train_stage == 2: + if stage_epoch >= 25: + train_stage += 1 + stage_epoch = 0 - if train_stage >= 2: - # 第二阶段后 L1 损失不再应该生效 + if train_stage >= 3: + # 第三阶段后交叉熵损失不再应该生效 mask_ratio = 1.0 + if last_stage: + if train_stage == 2 and stage_epoch % 5 == 0: + train_stage = 4 + + if train_stage == 4 and stage_epoch % 5 == 1: + train_stage = 2 + + stage_epoch += 1 + dataset.train_stage = train_stage dataset_val.train_stage = train_stage - dataset.random_ratio = random_ratio - dataset_val.random_ratio = random_ratio 
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio @@ -284,8 +300,8 @@ def train(): "g_steps": g_steps, "stage": train_stage, "mask_ratio": mask_ratio, - "random_ratio": random_ratio, - "stage3_epoch": stage3_epoch, + "stage_epoch": stage_epoch, + "last_stage": last_stage }, f"result/wgan/ginka-{epoch + 1}.pth") torch.save({ "model_state": minamo.state_dict(), @@ -300,7 +316,8 @@ def train(): fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True) elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3 = gen_total(ginka, masked1, True, True) + input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1) + fake1, fake2, fake3 = gen_total(ginka, input, True, True) fake1 = torch.argmax(fake1, dim=1).cpu().numpy() fake2 = torch.argmax(fake2, dim=1).cpu().numpy() diff --git a/tiles/13.png b/tiles/13.png new file mode 100644 index 0000000..4b8d3a6 Binary files /dev/null and b/tiles/13.png differ diff --git a/train.sh b/train.sh index de59a05..152c69e 100644 --- a/train.sh +++ b/train.sh @@ -1,4 +1,4 @@ # 从头训练 python3 -u -m ginka.train_wgan >> output.log # 接续训练 -python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log \ No newline at end of file +python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log