mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-24 21:31:10 +08:00
perf: 改进网络结构
This commit is contained in:
parent
a7d21260e4
commit
53041ab754
@ -7,28 +7,41 @@ class ConditionEncoder(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
|
self.tag_embed = nn.Linear(tag_dim, hidden_dim)
|
||||||
self.val_embed = nn.Linear(val_dim, hidden_dim)
|
self.val_embed = nn.Linear(val_dim, hidden_dim)
|
||||||
|
self.encoder = nn.TransformerEncoder(
|
||||||
|
nn.TransformerEncoderLayer(
|
||||||
|
d_model=hidden_dim, nhead=8, dim_feedforward=hidden_dim*4,
|
||||||
|
batch_first=True
|
||||||
|
),
|
||||||
|
num_layers=6
|
||||||
|
)
|
||||||
self.fusion = nn.Sequential(
|
self.fusion = nn.Sequential(
|
||||||
|
nn.Linear(hidden_dim, hidden_dim*2),
|
||||||
nn.LayerNorm(hidden_dim*2),
|
nn.LayerNorm(hidden_dim*2),
|
||||||
nn.ELU(),
|
nn.ELU(),
|
||||||
|
|
||||||
nn.Linear(hidden_dim*2, hidden_dim*4),
|
nn.Linear(hidden_dim*2, out_dim)
|
||||||
nn.LayerNorm(hidden_dim*4),
|
|
||||||
nn.ELU(),
|
|
||||||
|
|
||||||
nn.Linear(hidden_dim*4, out_dim)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, tag, val):
|
def forward(self, tag, val):
|
||||||
tag = self.tag_embed(tag)
|
tag = self.tag_embed(tag)
|
||||||
val = self.val_embed(val)
|
val = self.val_embed(val)
|
||||||
feat = torch.cat([tag, val], dim=1)
|
feat = torch.stack([tag, val], dim=1)
|
||||||
|
feat = self.encoder(feat)
|
||||||
|
feat = torch.mean(feat, dim=1)
|
||||||
feat = self.fusion(feat)
|
feat = self.fusion(feat)
|
||||||
return feat
|
return feat
|
||||||
|
|
||||||
class ConditionInjector(nn.Module):
|
class ConditionInjector(nn.Module):
|
||||||
def __init__(self, cond_dim, out_dim):
|
def __init__(self, cond_dim, out_dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.fc = nn.Sequential(
|
self.gamma_layer = nn.Sequential(
|
||||||
|
nn.Linear(cond_dim, cond_dim*2),
|
||||||
|
nn.LayerNorm(cond_dim*2),
|
||||||
|
nn.ELU(),
|
||||||
|
|
||||||
|
nn.Linear(cond_dim*2, out_dim)
|
||||||
|
)
|
||||||
|
self.beta_layer = nn.Sequential(
|
||||||
nn.Linear(cond_dim, cond_dim*2),
|
nn.Linear(cond_dim, cond_dim*2),
|
||||||
nn.LayerNorm(cond_dim*2),
|
nn.LayerNorm(cond_dim*2),
|
||||||
nn.ELU(),
|
nn.ELU(),
|
||||||
@ -37,7 +50,6 @@ class ConditionInjector(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x, cond):
|
def forward(self, x, cond):
|
||||||
cond = self.fc(cond)
|
gamma = self.gamma_layer(cond).unsqueeze(2).unsqueeze(3)
|
||||||
B, D = cond.shape
|
beta = self.beta_layer(cond).unsqueeze(2).unsqueeze(3)
|
||||||
cond = cond.view(B, D, 1, 1)
|
return x * gamma + beta
|
||||||
return x + cond
|
|
||||||
|
|||||||
@ -2,12 +2,12 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.nn.utils import spectral_norm
|
from torch.nn.utils import spectral_norm
|
||||||
from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
|
from torch_geometric.nn import global_max_pool, GCNConv
|
||||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||||
from shared.graph import batch_convert_soft_map_to_graph
|
from shared.graph import batch_convert_soft_map_to_graph
|
||||||
from .vision import MinamoVisionModel
|
from .vision import MinamoVisionModel
|
||||||
from .topo import MinamoTopoModel
|
from .topo import MinamoTopoModel
|
||||||
from ..common.cond import ConditionEncoder, ConditionInjector
|
from ..common.cond import ConditionEncoder
|
||||||
|
|
||||||
def print_memory(tag=""):
|
def print_memory(tag=""):
|
||||||
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
||||||
@ -24,7 +24,7 @@ class CNNHead(nn.Module):
|
|||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
spectral_norm(nn.Linear(in_ch*2*2, 1))
|
spectral_norm(nn.Linear(in_ch*2*2, 1))
|
||||||
)
|
)
|
||||||
self.proj = nn.Linear(256, in_ch*2*2)
|
self.proj = spectral_norm(nn.Linear(256, in_ch*2*2))
|
||||||
|
|
||||||
def forward(self, x, cond):
|
def forward(self, x, cond):
|
||||||
x = self.cnn(x)
|
x = self.cnn(x)
|
||||||
@ -39,7 +39,7 @@ class GCNHead(nn.Module):
|
|||||||
def __init__(self, in_dim):
|
def __init__(self, in_dim):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gcn = GCNConv(in_dim, in_dim)
|
self.gcn = GCNConv(in_dim, in_dim)
|
||||||
self.proj = nn.Linear(256, in_dim)
|
self.proj = spectral_norm(nn.Linear(256, in_dim))
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
spectral_norm(nn.Linear(in_dim, 1))
|
spectral_norm(nn.Linear(in_dim, 1))
|
||||||
)
|
)
|
||||||
@ -69,7 +69,7 @@ class MinamoModel(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.topo_model = MinamoTopoModel(tile_types)
|
self.topo_model = MinamoTopoModel(tile_types)
|
||||||
self.vision_model = MinamoVisionModel(tile_types)
|
self.vision_model = MinamoVisionModel(tile_types)
|
||||||
self.cond = ConditionEncoder(64, 16, 128, 256)
|
self.cond = ConditionEncoder(64, 16, 256, 256)
|
||||||
# 输出层
|
# 输出层
|
||||||
self.head1 = MinamoScoreHead(512, 512)
|
self.head1 = MinamoScoreHead(512, 512)
|
||||||
self.head2 = MinamoScoreHead(512, 512)
|
self.head2 = MinamoScoreHead(512, 512)
|
||||||
|
|||||||
@ -51,7 +51,6 @@ def apply_curriculum_mask(
|
|||||||
mask_ratio: float # 遮挡比例 0~1
|
mask_ratio: float # 遮挡比例 0~1
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
C, H, W = maps.shape
|
C, H, W = maps.shape
|
||||||
device = maps.device
|
|
||||||
masked_maps = maps.clone()
|
masked_maps = maps.clone()
|
||||||
|
|
||||||
# Step 1: 移除不需要的类别(全设为 0 类)
|
# Step 1: 移除不需要的类别(全设为 0 类)
|
||||||
|
|||||||
@ -347,7 +347,7 @@ def immutable_penalty_loss(
|
|||||||
return penalty
|
return penalty
|
||||||
|
|
||||||
class WGANGinkaLoss:
|
class WGANGinkaLoss:
|
||||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
|
def __init__(self, lambda_gp=100, weight=[1, 0.5, 50, 0.2, 0.2, 0.2]):
|
||||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
|
|||||||
@ -15,7 +15,7 @@ class GinkaModel(nn.Module):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.head = RandomInputHead()
|
self.head = RandomInputHead()
|
||||||
self.cond = ConditionEncoder(64, 16, 128, 256)
|
self.cond = ConditionEncoder(64, 16, 256, 256)
|
||||||
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
|
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
|
||||||
self.unet = GinkaUNet(32, base_ch, base_ch)
|
self.unet = GinkaUNet(32, base_ch, base_ch)
|
||||||
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
||||||
|
|||||||
@ -10,7 +10,11 @@ class StageHead(nn.Module):
|
|||||||
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
|
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
|
||||||
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
|
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
|
||||||
self.pool = nn.Sequential(
|
self.pool = nn.Sequential(
|
||||||
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
|
nn.Conv2d(in_ch, in_ch*2, 3, padding=1, padding_mode='replicate'),
|
||||||
|
nn.InstanceNorm2d(in_ch*2),
|
||||||
|
nn.ELU(),
|
||||||
|
|
||||||
|
nn.Conv2d(in_ch*2, in_ch, 3, padding=1, padding_mode='replicate'),
|
||||||
nn.InstanceNorm2d(in_ch),
|
nn.InstanceNorm2d(in_ch),
|
||||||
nn.ELU(),
|
nn.ELU(),
|
||||||
|
|
||||||
|
|||||||
@ -167,10 +167,6 @@ class GinkaUNet(nn.Module):
|
|||||||
"""Ginka Model UNet 部分
|
"""Ginka Model UNet 部分
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# self.input = GinkaTransformerEncoder(
|
|
||||||
# in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
|
|
||||||
# token_size=4, ff_dim=feat_dim*2, num_layers=4
|
|
||||||
# )
|
|
||||||
self.down1 = ConvBlock(in_ch, base_ch)
|
self.down1 = ConvBlock(in_ch, base_ch)
|
||||||
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
|
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
|
||||||
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
|
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from tqdm import tqdm
|
|||||||
from .generator.model import GinkaModel
|
from .generator.model import GinkaModel
|
||||||
from .dataset import GinkaWGANDataset
|
from .dataset import GinkaWGANDataset
|
||||||
from .generator.loss import WGANGinkaLoss
|
from .generator.loss import WGANGinkaLoss
|
||||||
from .generator.input import RandomInputHead
|
|
||||||
from .critic.model import MinamoModel
|
from .critic.model import MinamoModel
|
||||||
from shared.image import matrix_to_image_cv
|
from shared.image import matrix_to_image_cv
|
||||||
|
|
||||||
@ -106,13 +105,12 @@ def train():
|
|||||||
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
||||||
|
|
||||||
ginka = GinkaModel().to(device)
|
ginka = GinkaModel().to(device)
|
||||||
ginka_head = RandomInputHead().to(device)
|
|
||||||
minamo = MinamoModel().to(device)
|
minamo = MinamoModel().to(device)
|
||||||
|
|
||||||
dataset = GinkaWGANDataset(args.train, device)
|
dataset = GinkaWGANDataset(args.train, device)
|
||||||
dataset_val = GinkaWGANDataset(args.validate, device)
|
dataset_val = GinkaWGANDataset(args.validate, device)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
||||||
|
|
||||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
||||||
@ -270,47 +268,6 @@ def train():
|
|||||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if avg_loss_ce < 0.5:
|
|
||||||
low_loss_epochs += 1
|
|
||||||
else:
|
|
||||||
low_loss_epochs = 0
|
|
||||||
|
|
||||||
# 训练流程控制
|
|
||||||
|
|
||||||
if train_stage >= 2:
|
|
||||||
train_stage += 1
|
|
||||||
|
|
||||||
if train_stage == 5:
|
|
||||||
train_stage = 2
|
|
||||||
|
|
||||||
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
|
||||||
if mask_ratio >= 0.9:
|
|
||||||
train_stage = 2
|
|
||||||
mask_ratio += 0.2
|
|
||||||
mask_ratio = min(mask_ratio, 0.9)
|
|
||||||
low_loss_epochs = 0
|
|
||||||
stage_epoch = 0
|
|
||||||
|
|
||||||
stage_epoch += 1
|
|
||||||
|
|
||||||
dataset.train_stage = train_stage
|
|
||||||
dataset_val.train_stage = train_stage
|
|
||||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
|
||||||
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
|
||||||
|
|
||||||
# scheduler_ginka.step()
|
|
||||||
# scheduler_minamo.step()
|
|
||||||
|
|
||||||
if avg_dis < 0:
|
|
||||||
g_steps = max(int(-avg_dis * 5), 1)
|
|
||||||
else:
|
|
||||||
g_steps = 1
|
|
||||||
|
|
||||||
if avg_loss_minamo > 0:
|
|
||||||
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
|
||||||
else:
|
|
||||||
c_steps = 5
|
|
||||||
|
|
||||||
# 每若干轮输出一次图片,并保存检查点
|
# 每若干轮输出一次图片,并保存检查点
|
||||||
if (epoch + 1) % args.checkpoint == 0:
|
if (epoch + 1) % args.checkpoint == 0:
|
||||||
# 保存检查点
|
# 保存检查点
|
||||||
@ -344,8 +301,7 @@ def train():
|
|||||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, tag_cond, val_cond, True)
|
||||||
|
|
||||||
elif train_stage == 3 or train_stage == 4:
|
elif train_stage == 3 or train_stage == 4:
|
||||||
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
|
fake1, fake2, fake3, _ = gen_total(ginka, masked1, tag_cond, val_cond, True, True, train_stage == 4)
|
||||||
fake1, fake2, fake3, _ = gen_total(ginka, input, tag_cond, val_cond, True, True, train_stage == 4)
|
|
||||||
|
|
||||||
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
||||||
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
||||||
@ -359,6 +315,49 @@ def train():
|
|||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
|
# 训练流程控制
|
||||||
|
|
||||||
|
if mask_ratio < 0.5 and avg_loss_ce < 0.2:
|
||||||
|
low_loss_epochs += 1
|
||||||
|
elif mask_ratio > 0.5 and avg_loss_ce < 0.3:
|
||||||
|
low_loss_epochs += 1
|
||||||
|
else:
|
||||||
|
low_loss_epochs = 0
|
||||||
|
|
||||||
|
if train_stage >= 2:
|
||||||
|
train_stage += 1
|
||||||
|
|
||||||
|
if train_stage == 5:
|
||||||
|
train_stage = 2
|
||||||
|
|
||||||
|
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= curr_epoch:
|
||||||
|
if mask_ratio >= 0.9:
|
||||||
|
train_stage = 2
|
||||||
|
mask_ratio += 0.2
|
||||||
|
mask_ratio = min(mask_ratio, 0.9)
|
||||||
|
low_loss_epochs = 0
|
||||||
|
stage_epoch = 0
|
||||||
|
|
||||||
|
stage_epoch += 1
|
||||||
|
|
||||||
|
# scheduler_ginka.step()
|
||||||
|
# scheduler_minamo.step()
|
||||||
|
|
||||||
|
if avg_dis < 0:
|
||||||
|
g_steps = max(int(-avg_dis * 5), 1)
|
||||||
|
else:
|
||||||
|
g_steps = 1
|
||||||
|
|
||||||
|
if avg_loss_minamo > 0:
|
||||||
|
c_steps = int(min(5 + avg_loss_minamo * 5, 15))
|
||||||
|
else:
|
||||||
|
c_steps = 5
|
||||||
|
|
||||||
|
dataset.train_stage = train_stage
|
||||||
|
dataset_val.train_stage = train_stage
|
||||||
|
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||||
|
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": ginka.state_dict(),
|
"model_state": ginka.state_dict(),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user