mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 08:21:11 +08:00
perf: 调优超参数和模型
This commit is contained in:
parent
8130296e1f
commit
268b21e0b7
@ -63,12 +63,12 @@ class GinkaWGANDataset(Dataset):
|
||||
item = self.data[idx]
|
||||
|
||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
min_main = random.uniform(0.75, 0.9)
|
||||
max_main = random.uniform(0.9, 1)
|
||||
epsilon = random.uniform(0, 0.25)
|
||||
target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
|
||||
# min_main = random.uniform(0.8, 0.9)
|
||||
# max_main = random.uniform(0.9, 1)
|
||||
# epsilon = random.uniform(0, 0.2)
|
||||
# target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
|
||||
|
||||
return target_smooth
|
||||
return target
|
||||
|
||||
class MinamoGANDataset(Dataset):
|
||||
def __init__(self, refer_data_path):
|
||||
|
||||
@ -311,7 +311,7 @@ def js_divergence(P, Q, epsilon=1e-10):
|
||||
return js.mean() # 标量
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=20, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
|
||||
def __init__(self, lambda_gp=50, weight=[0.7, 0.2, 0.1], diversity_lamda=0.2):
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
self.diversity_lamda = diversity_lamda
|
||||
@ -361,6 +361,8 @@ class WGANGinkaLoss:
|
||||
real_scores, _, _ = critic(real_data, real_graph)
|
||||
fake_scores, _, _ = critic(fake_data, fake_graph)
|
||||
|
||||
# print("Critic 输出范围", fake_scores.min().item(), fake_scores.max().item(), real_scores.min().item(), real_scores.max().item())
|
||||
|
||||
# Wasserstein 距离
|
||||
d_loss = fake_scores.mean() - real_scores.mean()
|
||||
grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data)
|
||||
@ -381,10 +383,16 @@ class WGANGinkaLoss:
|
||||
]
|
||||
|
||||
return sum(losses)
|
||||
|
||||
def diversity_loss(self, fake1, fake2):
|
||||
fake1 = fake1[:, :, 1:-1, 1:-1]
|
||||
fake2 = fake2[:, :, 1:-1, 1:-1]
|
||||
|
||||
return js_divergence(fake1, fake2)
|
||||
|
||||
def generator_loss(self, critic, fake1, fake2):
|
||||
""" 生成器损失函数 """
|
||||
loss1 = self.generator_loss_one(critic, fake1)
|
||||
loss2 = self.generator_loss_one(critic, fake2)
|
||||
|
||||
return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * js_divergence(fake1, fake2)
|
||||
return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * self.diversity_loss(fake1, fake2)
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
@ -20,6 +21,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/wgan", exist_ok=True)
|
||||
|
||||
disable_tqdm = not sys.stdout.isatty()
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
@ -37,11 +40,13 @@ def clip_weights(model, clip_value=0.01):
|
||||
def train():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||
|
||||
c_steps = 1
|
||||
g_steps = 4
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
# c_steps = 1 if args.resume else 5
|
||||
# g_steps = 5 if args.resume else 1
|
||||
c_steps = 5
|
||||
g_steps = 1
|
||||
|
||||
ginka = GinkaModel()
|
||||
minamo = MinamoScoreModule()
|
||||
ginka.to(device)
|
||||
@ -50,8 +55,8 @@ def train():
|
||||
dataset = GinkaWGANDataset(args.train, device)
|
||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
|
||||
|
||||
optimizer_ginka = optim.RMSprop(ginka.parameters(), lr=2e-4)
|
||||
optimizer_minamo = optim.RMSprop(minamo.parameters(), lr=1e-5)
|
||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
||||
|
||||
criterion = WGANGinkaLoss()
|
||||
|
||||
@ -68,12 +73,12 @@ def train():
|
||||
minamo.load_state_dict(data["model_state"], strict=False)
|
||||
print("Train from loaded state.")
|
||||
|
||||
for epoch in tqdm(range(args.epochs), desc="GAN Training"):
|
||||
for epoch in tqdm(range(args.epochs), desc="GAN Training", disable=disable_tqdm):
|
||||
loss_total_minamo = torch.Tensor([0]).to(device)
|
||||
loss_total_ginka = torch.Tensor([0]).to(device)
|
||||
dis_total = torch.Tensor([0]).to(device)
|
||||
|
||||
for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress"):
|
||||
for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
batch_size = real_data.size(0)
|
||||
real_data = real_data.to(device)
|
||||
real_graph = batch_convert_soft_map_to_graph(real_data)
|
||||
@ -92,7 +97,11 @@ def train():
|
||||
# 反向传播
|
||||
dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
|
||||
loss_d.backward()
|
||||
# torch.nn.utils.clip_grad_norm_(minamo_vis.parameters(), max_norm=1.0)
|
||||
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=1.0)
|
||||
# total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
|
||||
# print("Critic 梯度范数:", total_norm.item())
|
||||
# print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
|
||||
# print("Critic 输出范围:", d_real.min().item(), d_real.max().item())
|
||||
optimizer_minamo.step()
|
||||
|
||||
loss_total_minamo += loss_d
|
||||
@ -119,21 +128,27 @@ def train():
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}"
|
||||
)
|
||||
|
||||
if avg_dis < -9:
|
||||
g_steps = 21
|
||||
elif avg_dis < -6:
|
||||
g_steps = 14
|
||||
elif avg_dis < -3:
|
||||
g_steps = 7
|
||||
if avg_dis < 0:
|
||||
g_steps = max(int(-avg_dis * 5), 1)
|
||||
else:
|
||||
g_steps = 1
|
||||
|
||||
# if avg_dis > 0:
|
||||
# c_steps = min(max(int(avg_dis * 5), 1), 5)
|
||||
# else:
|
||||
# c_steps = 1
|
||||
|
||||
if avg_dis > 3:
|
||||
c_steps = 3
|
||||
else:
|
||||
c_steps = 1
|
||||
# if avg_loss_minamo > 0:
|
||||
# c_steps += min(max(int(avg_loss_minamo * 3), 1), 5)
|
||||
# else:
|
||||
# c_steps += 0
|
||||
|
||||
# 每五轮输出一次图片,并保存检查点
|
||||
# if avg_dis > 3:
|
||||
# c_steps = 3
|
||||
# else:
|
||||
# c_steps = 1
|
||||
|
||||
# 每若干轮输出一次图片,并保存检查点
|
||||
if (epoch + 1) % 5 == 0:
|
||||
# 输出 20 张图片,每批次 4 张,一共五批
|
||||
idx = 0
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from .vision import MinamoVisionModel
|
||||
from .topo import MinamoTopoModel
|
||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||
@ -60,14 +61,10 @@ class MinamoScoreModule(nn.Module):
|
||||
self.vision_model = MinamoVisionModel(tile_types)
|
||||
# 输出层
|
||||
self.topo_fc = nn.Sequential(
|
||||
nn.Linear(512, 2048),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Linear(2048, 1)
|
||||
spectral_norm(nn.Linear(512, 1)),
|
||||
)
|
||||
self.vision_fc = nn.Sequential(
|
||||
nn.Linear(512, 2048),
|
||||
nn.LeakyReLU(0.2),
|
||||
nn.Linear(2048, 1)
|
||||
spectral_norm(nn.Linear(512, 1)),
|
||||
)
|
||||
|
||||
def forward(self, map, graph):
|
||||
|
||||
@ -12,7 +12,7 @@ class MinamoTopoModel(nn.Module):
|
||||
super().__init__()
|
||||
# 传入 softmax 概率值,直接映射
|
||||
self.input_proj = nn.Sequential(
|
||||
nn.Linear(tile_types, emb_dim),
|
||||
spectral_norm(nn.Linear(tile_types, emb_dim)),
|
||||
nn.LeakyReLU(0.2)
|
||||
)
|
||||
# 图卷积层
|
||||
@ -25,7 +25,7 @@ class MinamoTopoModel(nn.Module):
|
||||
# self.norm3 = nn.LayerNorm(out_dim)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(out_dim, feat_dim),
|
||||
spectral_norm(nn.Linear(out_dim, feat_dim)),
|
||||
nn.LeakyReLU(0.2)
|
||||
)
|
||||
|
||||
|
||||
@ -19,7 +19,8 @@ class MinamoVisionModel(nn.Module):
|
||||
nn.Flatten()
|
||||
)
|
||||
self.fc = nn.Sequential(
|
||||
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim))
|
||||
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)),
|
||||
nn.LeakyReLU(0.2)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
@ -2,5 +2,5 @@ VIS_DIM = 512
|
||||
TOPO_DIM = 512
|
||||
FEAT_DIM = 1024
|
||||
|
||||
VISION_WEIGHT = 0
|
||||
TOPO_WEIGHT = 1
|
||||
VISION_WEIGHT = 0.3
|
||||
TOPO_WEIGHT = 0.7
|
||||
|
||||
Loading…
Reference in New Issue
Block a user