From 53041ab75469c7470fd408b1df31a7a673c1c41c Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Thu, 1 May 2025 22:08:39 +0800
Subject: [PATCH] perf: improve network structure
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 ginka/common/cond.py      | 36 ++++++++++------
 ginka/critic/model.py     | 10 ++---
 ginka/dataset.py          |  1 -
 ginka/generator/loss.py   |  2 +-
 ginka/generator/model.py  |  2 +-
 ginka/generator/output.py |  6 ++-
 ginka/generator/unet.py   |  4 --
 ginka/train_wgan.py       | 91 +++++++++++++++++++--------------------
 8 files changed, 81 insertions(+), 71 deletions(-)

diff --git a/ginka/common/cond.py b/ginka/common/cond.py
index c2a3a9c..ac64021 100644
--- a/ginka/common/cond.py
+++ b/ginka/common/cond.py
@@ -7,37 +7,49 @@ class ConditionEncoder(nn.Module):
         super().__init__()
         self.tag_embed = nn.Linear(tag_dim, hidden_dim)
         self.val_embed = nn.Linear(val_dim, hidden_dim)
+        self.encoder = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(
+                d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
+                batch_first=True
+            ),
+            num_layers=6
+        )
         self.fusion = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim*2),
             nn.LayerNorm(hidden_dim*2),
             nn.ELU(),
-            nn.Linear(hidden_dim*2, hidden_dim*4),
-            nn.LayerNorm(hidden_dim*4),
-            nn.ELU(),
-            
-            nn.Linear(hidden_dim*4, out_dim)
+            nn.Linear(hidden_dim*2, out_dim)
         )
         
     def forward(self, tag, val):
         tag = self.tag_embed(tag)
         val = self.val_embed(val)
-        feat = torch.cat([tag, val], dim=1)
+        feat = torch.stack([tag, val], dim=1)
+        feat = self.encoder(feat)
+        feat = torch.mean(feat, dim=1)
         feat = self.fusion(feat)
         return feat
 
 class ConditionInjector(nn.Module):
     def __init__(self, cond_dim, out_dim):
         super().__init__()
-        self.fc = nn.Sequential(
+        self.gamma_layer = nn.Sequential(
             nn.Linear(cond_dim, cond_dim*2),
             nn.LayerNorm(cond_dim*2),
             nn.ELU(),
             
             nn.Linear(cond_dim*2, out_dim)
         )
-        
+        self.beta_layer = nn.Sequential(
+            nn.Linear(cond_dim, cond_dim*2),
+            nn.LayerNorm(cond_dim*2),
+            nn.ELU(),
+            
+            nn.Linear(cond_dim*2, out_dim)
+        )
+        
     def forward(self, x, cond):
-        cond = self.fc(cond)
-        B, D = cond.shape
-        cond = cond.view(B, D, 1, 1)
-        return x + cond
+        gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
+        beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
+        return x * gamma + beta
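Note: the rewritten ConditionInjector above amounts to FiLM-style
(feature-wise linear modulation) conditioning: instead of adding a projected
condition vector to the feature map, it predicts a per-channel scale (gamma)
and shift (beta) from the condition. A minimal standalone sketch of the same
idea follows; the module and variable names here are illustrative, not taken
from this repo:

    import torch
    import torch.nn as nn

    class FiLM(nn.Module):
        """Scale and shift each channel of a feature map using a condition vector."""
        def __init__(self, cond_dim: int, num_channels: int):
            super().__init__()
            self.to_gamma = nn.Linear(cond_dim, num_channels)
            self.to_beta = nn.Linear(cond_dim, num_channels)

        def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
            # x: (B, C, H, W); cond: (B, cond_dim)
            gamma = self.to_gamma(cond).unsqueeze(2).unsqueeze(3)  # (B, C, 1, 1)
            beta = self.to_beta(cond).unsqueeze(2).unsqueeze(3)    # (B, C, 1, 1)
            return x * gamma + beta

    # Usage: FiLM(256, 32)(torch.randn(2, 32, 13, 13), torch.randn(2, 256))

Multiplicative modulation lets the condition gate features on or off, which
is generally more expressive than the purely additive bias used before.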
diff --git a/ginka/critic/model.py b/ginka/critic/model.py
index dfe45a0..9ef80c3 100644
--- a/ginka/critic/model.py
+++ b/ginka/critic/model.py
@@ -2,12 +2,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils import spectral_norm
-from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
+from torch_geometric.nn import global_max_pool, GCNConv
 from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
 from shared.graph import batch_convert_soft_map_to_graph
 from .vision import MinamoVisionModel
 from .topo import MinamoTopoModel
-from ..common.cond import ConditionEncoder, ConditionInjector
+from ..common.cond import ConditionEncoder
 
 def print_memory(tag=""):
     print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -24,7 +24,7 @@ class CNNHead(nn.Module):
         self.fc = nn.Sequential(
             spectral_norm(nn.Linear(in_ch*2*2, 1))
         )
-        self.proj = nn.Linear(256, in_ch*2*2)
+        self.proj = spectral_norm(nn.Linear(256, in_ch*2*2))
 
     def forward(self, x, cond):
         x = self.cnn(x)
@@ -39,7 +39,7 @@ class GCNHead(nn.Module):
     def __init__(self, in_dim):
         super().__init__()
         self.gcn = GCNConv(in_dim, in_dim)
-        self.proj = nn.Linear(256, in_dim)
+        self.proj = spectral_norm(nn.Linear(256, in_dim))
         self.fc = nn.Sequential(
             spectral_norm(nn.Linear(in_dim, 1))
         )
@@ -69,7 +69,7 @@ class MinamoModel(nn.Module):
         super().__init__()
         self.topo_model = MinamoTopoModel(tile_types)
         self.vision_model = MinamoVisionModel(tile_types)
-        self.cond = ConditionEncoder(64, 16, 128, 256)
+        self.cond = ConditionEncoder(64, 16, 256, 256)
         # Output heads
         self.head1 = MinamoScoreHead(512, 512)
         self.head2 = MinamoScoreHead(512, 512)
diff --git a/ginka/dataset.py b/ginka/dataset.py
index c701c19..acc7823 100644
--- a/ginka/dataset.py
+++ b/ginka/dataset.py
@@ -51,7 +51,6 @@ def apply_curriculum_mask(
     mask_ratio: float  # mask ratio, 0~1
 ) -> torch.Tensor:
     C, H, W = maps.shape
-    device = maps.device
     masked_maps = maps.clone()
 
     # Step 1: remove unneeded tile classes (set them all to class 0)
diff --git a/ginka/generator/loss.py b/ginka/generator/loss.py
index 211bbf8..0e4edd6 100644
--- a/ginka/generator/loss.py
+++ b/ginka/generator/loss.py
@@ -347,7 +347,7 @@ def immutable_penalty_loss(
     return penalty
 
 class WGANGinkaLoss:
-    def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
+    def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
         # weight: critic loss, CE loss, immutable-tile loss, tile-type loss, entrance-presence loss, diversity loss
         self.lambda_gp = lambda_gp  # gradient penalty coefficient
         self.weight = weight
diff --git a/ginka/generator/model.py b/ginka/generator/model.py
index fee210d..71871a9 100644
--- a/ginka/generator/model.py
+++ b/ginka/generator/model.py
@@ -15,7 +15,7 @@ class GinkaModel(nn.Module):
         """
         super().__init__()
         self.head = RandomInputHead()
-        self.cond = ConditionEncoder(64, 16, 128, 256)
+        self.cond = ConditionEncoder(64, 16, 256, 256)
         self.input = GinkaInput(32, 32, (13, 13), (32, 32))
         self.unet = GinkaUNet(32, base_ch, base_ch)
         self.output = GinkaOutput(base_ch, out_ch, (13, 13))
diff --git a/ginka/generator/output.py b/ginka/generator/output.py
index 05802b6..f9fa6f8 100644
--- a/ginka/generator/output.py
+++ b/ginka/generator/output.py
@@ -10,7 +10,11 @@ class StageHead(nn.Module):
         self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
         self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
         self.pool = nn.Sequential(
-            nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
+            nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'),
+            nn.InstanceNorm2d(in_ch*2),
+            nn.ELU(),
+            
+            nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
             nn.InstanceNorm2d(in_ch),
             nn.ELU(),
 
diff --git a/ginka/generator/unet.py b/ginka/generator/unet.py
index 057da1e..f7bea1d 100644
--- a/ginka/generator/unet.py
+++ b/ginka/generator/unet.py
@@ -167,10 +167,6 @@ class GinkaUNet(nn.Module):
         """UNet part of the Ginka model
         """
         super().__init__()
-        # self.input = GinkaTransformerEncoder(
-        #     in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32,  # automatically divided by token_size
-        #     token_size=4, ff_dim=feat_dim*2, num_layers=4
-        # )
         self.down1 = ConvBlock(in_ch, base_ch)
         self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
         self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
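Note: the ginka/generator/loss.py hunk above raises the immutable-tile weight
from 10 to 50, while lambda_gp (default 100) remains the gradient-penalty
coefficient. For reference, the standard WGAN-GP penalty that such a
coefficient typically scales looks like the sketch below. This is a generic
illustration, not this repo's implementation; `critic` is assumed to map a
batch of samples to scalar scores:

    import torch

    def gradient_penalty(critic, real, fake, lambda_gp=100.0):
        # WGAN-GP: push the critic's gradient norm toward 1 at points
        # interpolated between real and generated samples.
        b = real.size(0)
        eps = torch.rand(b, 1, 1, 1, device=real.device)
        inter = (eps * real + (1 - eps) * fake).requires_grad_(True)
        scores = critic(inter)
        grads = torch.autograd.grad(
            outputs=scores, inputs=inter,
            grad_outputs=torch.ones_like(scores),
            create_graph=True,
        )[0]
        grad_norm = grads.view(b, -1).norm(2, dim=1)
        return lambda_gp * ((grad_norm - 1) ** 2).mean()

The spectral_norm wrappers added to the critic's projection layers in
ginka/critic/model.py serve the same goal: keeping the critic
Lipschitz-constrained so the Wasserstein estimate stays meaningful.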
diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py
index 10c73c9..48ada16 100644
--- a/ginka/train_wgan.py
+++ b/ginka/train_wgan.py
@@ -11,7 +11,6 @@
 from tqdm import tqdm
 from .generator.model import GinkaModel
 from .dataset import GinkaWGANDataset
 from .generator.loss import WGANGinkaLoss
-from .generator.input import RandomInputHead
 from .critic.model import MinamoModel
 from shared.image import matrix_to_image_cv
 
@@ -106,13 +105,12 @@ def train():
     stage_epoch = 0  # epochs spent in the current stage, used to control the training schedule
     
     ginka = GinkaModel().to(device)
-    ginka_head = RandomInputHead().to(device)
     minamo = MinamoModel().to(device)
     
     dataset = GinkaWGANDataset(args.train, device)
     dataset_val = GinkaWGANDataset(args.validate, device)
     dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
-    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
+    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
     
     optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
     optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
@@ -270,47 +268,6 @@ def train():
             f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
         )
         
-        if avg_loss_ce < 0.5:
-            low_loss_epochs += 1
-        else:
-            low_loss_epochs = 0
-        
-        # Training schedule control
-        
-        if train_stage >= 2:
-            train_stage += 1
-            
-        if train_stage == 5:
-            train_stage = 2
-        
-        if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
-            if mask_ratio >= 0.9:
-                train_stage = 2
-            mask_ratio += 0.2
-            mask_ratio = min(mask_ratio, 0.9)
-            low_loss_epochs = 0
-            stage_epoch = 0
-            
-        stage_epoch += 1
-        
-        dataset.train_stage = train_stage
-        dataset_val.train_stage = train_stage
-        dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
-        dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
-        
-        # scheduler_ginka.step()
-        # scheduler_minamo.step()
-        
-        if avg_dis < 0:
-            g_steps = max(int(-avg_dis * 5), 1)
-        else:
-            g_steps = 1
-            
-        if avg_loss_minamo > 0:
-            c_steps = int(min(5 + avg_loss_minamo * 5, 15))
-        else:
-            c_steps = 5
-        
         # Every few epochs, output sample images and save a checkpoint
         if (epoch + 1) % args.checkpoint == 0:
             # Save checkpoint
@@ -344,8 +301,7 @@ def train():
                 fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
             elif train_stage == 3 or train_stage == 4:
-                input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
-                fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4)
+                fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
             
             fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
             fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
@@ -358,6 +314,49 @@ def train():
                 cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
                 idx += 1
+        
+        # Training schedule control
+        
+        if mask_ratio < 0.5 and avg_loss_ce < 0.2:
+            low_loss_epochs += 1
+        elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
+            low_loss_epochs += 1
+        else:
+            low_loss_epochs = 0
+        
+        if train_stage >= 2:
+            train_stage += 1
+            
+        if train_stage == 5:
+            train_stage = 2
+        
+        if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
+            if mask_ratio >= 0.9:
+                train_stage = 2
+            mask_ratio += 0.2
+            mask_ratio = min(mask_ratio, 0.9)
+            low_loss_epochs = 0
+            stage_epoch = 0
+            
+        stage_epoch += 1
+        
+        # scheduler_ginka.step()
+        # scheduler_minamo.step()
+        
+        if avg_dis < 0:
+            g_steps = max(int(-avg_dis * 5), 1)
+        else:
+            g_steps = 1
+            
+        if avg_loss_minamo > 0:
+            c_steps = int(min(5 + avg_loss_minamo * 5, 15))
+        else:
+            c_steps = 5
+        
+        dataset.train_stage = train_stage
+        dataset_val.train_stage = train_stage
+        dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
+        dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
 
     print("Train ended.")
     torch.save({