perf: 改进网络结构

This commit is contained in:
unanmed 2025-05-01 22:08:39 +08:00
parent a7d21260e4
commit 53041ab754
8 changed files with 81 additions and 71 deletions

View File

@@ -7,37 +7,49 @@ class ConditionEncoder(nn.Module):
super().__init__()
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
self.val_embed = nn.Linear(val_dim, hidden_dim)
self.encoder = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
batch_first=True
),
num_layers=6
)
self.fusion = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim*2),
nn.LayerNorm(hidden_dim*2),
nn.ELU(),
nn.Linear(hidden_dim*2, hidden_dim*4),
nn.LayerNorm(hidden_dim*4),
nn.ELU(),
nn.Linear(hidden_dim*4, out_dim)
nn.Linear(hidden_dim*2, out_dim)
)
def forward(self, tag, val):
tag = self.tag_embed(tag)
val = self.val_embed(val)
feat = torch.cat([tag, val], dim=1)
feat = torch.stack([tag, val], dim=1)
feat = self.encoder(feat)
feat = torch.mean(feat, dim=1)
feat = self.fusion(feat)
return feat
class ConditionInjector(nn.Module):
def __init__(self, cond_dim, out_dim):
super().__init__()
self.fc = nn.Sequential(
self.gamma_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2),
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
)
self.beta_layer = nn.Sequential(
nn.Linear(cond_dim, cond_dim*2),
nn.LayerNorm(cond_dim*2),
nn.ELU(),
nn.Linear(cond_dim*2, out_dim)
)
def forward(self, x, cond):
cond = self.fc(cond)
B, D = cond.shape
cond = cond.view(B, D, 1, 1)
return x + cond
gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
return x * gamma + beta

View File

@@ -2,12 +2,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
from torch_geometric.nn import global_max_pool, GCNConv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.graph import batch_convert_soft_map_to_graph
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
from ..common.cond import ConditionEncoder, ConditionInjector
from ..common.cond import ConditionEncoder
def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -24,7 +24,7 @@ class CNNHead(nn.Module):
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*2*2, 1))
)
self.proj = nn.Linear(256, in_ch*2*2)
self.proj = spectral_norm(nn.Linear(256, in_ch*2*2))
def forward(self, x, cond):
x = self.cnn(x)
@@ -39,7 +39,7 @@ class GCNHead(nn.Module):
def __init__(self, in_dim):
super().__init__()
self.gcn = GCNConv(in_dim, in_dim)
self.proj = nn.Linear(256, in_dim)
self.proj = spectral_norm(nn.Linear(256, in_dim))
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_dim, 1))
)
@@ -69,7 +69,7 @@ class MinamoModel(nn.Module):
super().__init__()
self.topo_model = MinamoTopoModel(tile_types)
self.vision_model = MinamoVisionModel(tile_types)
self.cond = ConditionEncoder(64, 16, 128, 256)
self.cond = ConditionEncoder(64, 16, 256, 256)
# 输出层
self.head1 = MinamoScoreHead(512, 512)
self.head2 = MinamoScoreHead(512, 512)

View File

@@ -51,7 +51,6 @@ def apply_curriculum_mask(
mask_ratio: float # 遮挡比例 0~1
) -> torch.Tensor:
C, H, W = maps.shape
device = maps.device
masked_maps = maps.clone()
# Step 1: 移除不需要的类别(全设为 0 类)

View File

@@ -347,7 +347,7 @@ def immutable_penalty_loss(
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
# weight: 判别器损失CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight

View File

@@ -15,7 +15,7 @@ class GinkaModel(nn.Module):
"""
super().__init__()
self.head = RandomInputHead()
self.cond = ConditionEncoder(64, 16, 128, 256)
self.cond = ConditionEncoder(64, 16, 256, 256)
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
self.unet = GinkaUNet(32, base_ch, base_ch)
self.output = GinkaOutput(base_ch, out_ch, (13, 13))

View File

@@ -10,7 +10,11 @@ class StageHead(nn.Module):
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
self.pool = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch*2),
nn.ELU(),
nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch),
nn.ELU(),

View File

@@ -167,10 +167,6 @@ class GinkaUNet(nn.Module):
"""Ginka Model UNet 部分
"""
super().__init__()
# self.input = GinkaTransformerEncoder(
# in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
# token_size=4, ff_dim=feat_dim*2, num_layers=4
# )
self.down1 = ConvBlock(in_ch, base_ch)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)

View File

@@ -11,7 +11,6 @@ from tqdm import tqdm
from .generator.model import GinkaModel
from .dataset import GinkaWGANDataset
from .generator.loss import WGANGinkaLoss
from .generator.input import RandomInputHead
from .critic.model import MinamoModel
from shared.image import matrix_to_image_cv
@@ -106,13 +105,12 @@ def train():
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
ginka = GinkaModel().to(device)
ginka_head = RandomInputHead().to(device)
minamo = MinamoModel().to(device)
dataset = GinkaWGANDataset(args.train, device)
dataset_val = GinkaWGANDataset(args.validate, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
@@ -270,47 +268,6 @@ def train():
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
)
if avg_loss_ce < 0.5:
low_loss_epochs += 1
else:
low_loss_epochs = 0
# 训练流程控制
if train_stage >= 2:
train_stage += 1
if train_stage == 5:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
stage_epoch = 0
stage_epoch += 1
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
# scheduler_ginka.step()
# scheduler_minamo.step()
if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1)
else:
g_steps = 1
if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else:
c_steps = 5
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % args.checkpoint == 0:
# 保存检查点
@@ -344,8 +301,7 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
elif train_stage == 3 or train_stage == 4:
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4)
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
@@ -358,6 +314,49 @@ def train():
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
idx += 1
# 训练流程控制
if mask_ratio < 0.5 and avg_loss_ce < 0.2:
low_loss_epochs += 1
elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
low_loss_epochs += 1
else:
low_loss_epochs = 0
if train_stage >= 2:
train_stage += 1
if train_stage == 5:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
stage_epoch = 0
stage_epoch += 1
# scheduler_ginka.step()
# scheduler_minamo.step()
if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1)
else:
g_steps = 1
if avg_loss_minamo > 0:
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
else:
c_steps = 5
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
print("Train ended.")
torch.save({