mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
perf: 改进网络结构
This commit is contained in:
parent
a7d21260e4
commit
53041ab754
@ -7,37 +7,49 @@ class ConditionEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
|
||||
self.val_embed = nn.Linear(val_dim, hidden_dim)
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
|
||||
batch_first=True
|
||||
),
|
||||
num_layers=6
|
||||
)
|
||||
self.fusion = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim*2),
|
||||
nn.LayerNorm(hidden_dim*2),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Linear(hidden_dim*2, hidden_dim*4),
|
||||
nn.LayerNorm(hidden_dim*4),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Linear(hidden_dim*4, out_dim)
|
||||
nn.Linear(hidden_dim*2, out_dim)
|
||||
)
|
||||
|
||||
def forward(self, tag, val):
|
||||
tag = self.tag_embed(tag)
|
||||
val = self.val_embed(val)
|
||||
feat = torch.cat([tag, val], dim=1)
|
||||
feat = torch.stack([tag, val], dim=1)
|
||||
feat = self.encoder(feat)
|
||||
feat = torch.mean(feat, dim=1)
|
||||
feat = self.fusion(feat)
|
||||
return feat
|
||||
|
||||
class ConditionInjector(nn.Module):
|
||||
def __init__(self, cond_dim, out_dim):
|
||||
super().__init__()
|
||||
self.fc = nn.Sequential(
|
||||
self.gamma_layer = nn.Sequential(
|
||||
nn.Linear(cond_dim, cond_dim*2),
|
||||
nn.LayerNorm(cond_dim*2),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Linear(cond_dim*2, out_dim)
|
||||
)
|
||||
|
||||
self.beta_layer = nn.Sequential(
|
||||
nn.Linear(cond_dim, cond_dim*2),
|
||||
nn.LayerNorm(cond_dim*2),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Linear(cond_dim*2, out_dim)
|
||||
)
|
||||
|
||||
def forward(self, x, cond):
|
||||
cond = self.fc(cond)
|
||||
B, D = cond.shape
|
||||
cond = cond.view(B, D, 1, 1)
|
||||
return x + cond
|
||||
gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
|
||||
beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
|
||||
return x * gamma + beta
|
||||
|
||||
@ -2,12 +2,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
|
||||
from torch_geometric.nn import global_max_pool, GCNConv
|
||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||
from shared.graph import batch_convert_soft_map_to_graph
|
||||
from .vision import MinamoVisionModel
|
||||
from .topo import MinamoTopoModel
|
||||
from ..common.cond import ConditionEncoder, ConditionInjector
|
||||
from ..common.cond import ConditionEncoder
|
||||
|
||||
def print_memory(tag=""):
|
||||
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
||||
@ -24,7 +24,7 @@ class CNNHead(nn.Module):
|
||||
self.fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(in_ch*2*2, 1))
|
||||
)
|
||||
self.proj = nn.Linear(256, in_ch*2*2)
|
||||
self.proj = spectral_norm(nn.Linear(256, in_ch*2*2))
|
||||
|
||||
def forward(self, x, cond):
|
||||
x = self.cnn(x)
|
||||
@ -39,7 +39,7 @@ class GCNHead(nn.Module):
|
||||
def __init__(self, in_dim):
|
||||
super().__init__()
|
||||
self.gcn = GCNConv(in_dim, in_dim)
|
||||
self.proj = nn.Linear(256, in_dim)
|
||||
self.proj = spectral_norm(nn.Linear(256, in_dim))
|
||||
self.fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(in_dim, 1))
|
||||
)
|
||||
@ -69,7 +69,7 @@ class MinamoModel(nn.Module):
|
||||
super().__init__()
|
||||
self.topo_model = MinamoTopoModel(tile_types)
|
||||
self.vision_model = MinamoVisionModel(tile_types)
|
||||
self.cond = ConditionEncoder(64, 16, 128, 256)
|
||||
self.cond = ConditionEncoder(64, 16, 256, 256)
|
||||
# 输出层
|
||||
self.head1 = MinamoScoreHead(512, 512)
|
||||
self.head2 = MinamoScoreHead(512, 512)
|
||||
|
||||
@ -51,7 +51,6 @@ def apply_curriculum_mask(
|
||||
mask_ratio: float # 遮挡比例 0~1
|
||||
) -> torch.Tensor:
|
||||
C, H, W = maps.shape
|
||||
device = maps.device
|
||||
masked_maps = maps.clone()
|
||||
|
||||
# Step 1: 移除不需要的类别(全设为 0 类)
|
||||
|
||||
@ -347,7 +347,7 @@ def immutable_penalty_loss(
|
||||
return penalty
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
|
||||
@ -15,7 +15,7 @@ class GinkaModel(nn.Module):
|
||||
"""
|
||||
super().__init__()
|
||||
self.head = RandomInputHead()
|
||||
self.cond = ConditionEncoder(64, 16, 128, 256)
|
||||
self.cond = ConditionEncoder(64, 16, 256, 256)
|
||||
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
|
||||
self.unet = GinkaUNet(32, base_ch, base_ch)
|
||||
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
||||
|
||||
@ -10,7 +10,11 @@ class StageHead(nn.Module):
|
||||
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
|
||||
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
|
||||
self.pool = nn.Sequential(
|
||||
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
|
||||
nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(in_ch*2),
|
||||
nn.ELU(),
|
||||
|
||||
nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(in_ch),
|
||||
nn.ELU(),
|
||||
|
||||
|
||||
@ -167,10 +167,6 @@ class GinkaUNet(nn.Module):
|
||||
"""Ginka Model UNet 部分
|
||||
"""
|
||||
super().__init__()
|
||||
# self.input = GinkaTransformerEncoder(
|
||||
# in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
|
||||
# token_size=4, ff_dim=feat_dim*2, num_layers=4
|
||||
# )
|
||||
self.down1 = ConvBlock(in_ch, base_ch)
|
||||
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
|
||||
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
|
||||
|
||||
@ -11,7 +11,6 @@ from tqdm import tqdm
|
||||
from .generator.model import GinkaModel
|
||||
from .dataset import GinkaWGANDataset
|
||||
from .generator.loss import WGANGinkaLoss
|
||||
from .generator.input import RandomInputHead
|
||||
from .critic.model import MinamoModel
|
||||
from shared.image import matrix_to_image_cv
|
||||
|
||||
@ -106,13 +105,12 @@ def train():
|
||||
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
||||
|
||||
ginka = GinkaModel().to(device)
|
||||
ginka_head = RandomInputHead().to(device)
|
||||
minamo = MinamoModel().to(device)
|
||||
|
||||
dataset = GinkaWGANDataset(args.train, device)
|
||||
dataset_val = GinkaWGANDataset(args.validate, device)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
||||
|
||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
||||
@ -270,47 +268,6 @@ def train():
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
||||
)
|
||||
|
||||
if avg_loss_ce < 0.5:
|
||||
low_loss_epochs += 1
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
# 训练流程控制
|
||||
|
||||
if train_stage >= 2:
|
||||
train_stage += 1
|
||||
|
||||
if train_stage == 5:
|
||||
train_stage = 2
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
mask_ratio += 0.2
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
stage_epoch = 0
|
||||
|
||||
stage_epoch += 1
|
||||
|
||||
dataset.train_stage = train_stage
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
||||
|
||||
# scheduler_ginka.step()
|
||||
# scheduler_minamo.step()
|
||||
|
||||
if avg_dis < 0:
|
||||
g_steps = max(int(-avg_dis * 5), 1)
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
if avg_loss_minamo > 0:
|
||||
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||
else:
|
||||
c_steps = 5
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % args.checkpoint == 0:
|
||||
# 保存检查点
|
||||
@ -344,8 +301,7 @@ def train():
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
|
||||
fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4)
|
||||
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
|
||||
|
||||
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
||||
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
||||
@ -358,6 +314,49 @@ def train():
|
||||
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
|
||||
|
||||
idx += 1
|
||||
|
||||
# 训练流程控制
|
||||
|
||||
if mask_ratio < 0.5 and avg_loss_ce < 0.2:
|
||||
low_loss_epochs += 1
|
||||
elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
|
||||
low_loss_epochs += 1
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
if train_stage >= 2:
|
||||
train_stage += 1
|
||||
|
||||
if train_stage == 5:
|
||||
train_stage = 2
|
||||
|
||||
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
mask_ratio += 0.2
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
stage_epoch = 0
|
||||
|
||||
stage_epoch += 1
|
||||
|
||||
# scheduler_ginka.step()
|
||||
# scheduler_minamo.step()
|
||||
|
||||
if avg_dis < 0:
|
||||
g_steps = max(int(-avg_dis * 5), 1)
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
if avg_loss_minamo > 0:
|
||||
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||
else:
|
||||
c_steps = 5
|
||||
|
||||
dataset.train_stage = train_stage
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
||||
|
||||
print("Train ended.")
|
||||
torch.save({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user