perf: 调优超参数和模型

This commit is contained in:
unanmed 2025-04-07 16:41:27 +08:00
parent 8130296e1f
commit 268b21e0b7
7 changed files with 58 additions and 37 deletions

View File

@ -63,12 +63,12 @@ class GinkaWGANDataset(Dataset):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
min_main = random.uniform(0.75, 0.9)
max_main = random.uniform(0.9, 1)
epsilon = random.uniform(0, 0.25)
target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
# min_main = random.uniform(0.8, 0.9)
# max_main = random.uniform(0.9, 1)
# epsilon = random.uniform(0, 0.2)
# target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
return target_smooth
return target
class MinamoGANDataset(Dataset):
def __init__(self, refer_data_path):

View File

@ -311,7 +311,7 @@ def js_divergence(P, Q, epsilon=1e-10):
return js.mean() # 标量
class WGANGinkaLoss:
def __init__(self, lambda_gp=20, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
def __init__(self, lambda_gp=50, weight=[0.7, 0.2, 0.1], diversity_lamda=0.2):
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
self.diversity_lamda = diversity_lamda
@ -361,6 +361,8 @@ class WGANGinkaLoss:
real_scores, _, _ = critic(real_data, real_graph)
fake_scores, _, _ = critic(fake_data, fake_graph)
# print("Critic 输出范围", fake_scores.min().item(), fake_scores.max().item(), real_scores.min().item(), real_scores.max().item())
# Wasserstein 距离
d_loss = fake_scores.mean() - real_scores.mean()
grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data)
@ -381,10 +383,16 @@ class WGANGinkaLoss:
]
return sum(losses)
def diversity_loss(self, fake1, fake2):
    """Return the JS divergence between two generated batches, borders excluded.

    The first and last index along dimensions 2 and 3 are dropped from both
    inputs before they are handed to ``js_divergence``.

    Args:
        fake1: First generated batch; indexed as a 4-D tensor
            (presumably [B, C, H, W] -- confirm against the generator output).
        fake2: Second generated batch, indexed the same way as ``fake1``.

    Returns:
        Whatever ``js_divergence`` returns for the two cropped tensors.
    """
    # Same selection as [:, :, 1:-1, 1:-1], spelled once and reused for both inputs.
    interior = (slice(None), slice(None), slice(1, -1), slice(1, -1))
    return js_divergence(fake1[interior], fake2[interior])
def generator_loss(self, critic, fake1, fake2):
""" 生成器损失函数 """
loss1 = self.generator_loss_one(critic, fake1)
loss2 = self.generator_loss_one(critic, fake2)
return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * js_divergence(fake1, fake2)
return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * self.diversity_loss(fake1, fake2)

View File

@ -1,5 +1,6 @@
import argparse
import os
import sys
from datetime import datetime
import torch
import torch.optim as optim
@ -20,6 +21,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/wgan", exist_ok=True)
disable_tqdm = not sys.stdout.isatty()
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
@ -37,11 +40,13 @@ def clip_weights(model, clip_value=0.01):
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
c_steps = 1
g_steps = 4
args = parse_arguments()
# c_steps = 1 if args.resume else 5
# g_steps = 5 if args.resume else 1
c_steps = 5
g_steps = 1
ginka = GinkaModel()
minamo = MinamoScoreModule()
ginka.to(device)
@ -50,8 +55,8 @@ def train():
dataset = GinkaWGANDataset(args.train, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
optimizer_ginka = optim.RMSprop(ginka.parameters(), lr=2e-4)
optimizer_minamo = optim.RMSprop(minamo.parameters(), lr=1e-5)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
criterion = WGANGinkaLoss()
@ -68,12 +73,12 @@ def train():
minamo.load_state_dict(data["model_state"], strict=False)
print("Train from loaded state.")
for epoch in tqdm(range(args.epochs), desc="GAN Training"):
for epoch in tqdm(range(args.epochs), desc="GAN Training", disable=disable_tqdm):
loss_total_minamo = torch.Tensor([0]).to(device)
loss_total_ginka = torch.Tensor([0]).to(device)
dis_total = torch.Tensor([0]).to(device)
for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress"):
for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
batch_size = real_data.size(0)
real_data = real_data.to(device)
real_graph = batch_convert_soft_map_to_graph(real_data)
@ -92,7 +97,11 @@ def train():
# 反向传播
dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
loss_d.backward()
# torch.nn.utils.clip_grad_norm_(minamo_vis.parameters(), max_norm=1.0)
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=1.0)
# total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
# print("Critic 梯度范数:", total_norm.item())
# print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
# print("Critic 输出范围:", d_real.min().item(), d_real.max().item())
optimizer_minamo.step()
loss_total_minamo += loss_d
@ -119,21 +128,27 @@ def train():
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}"
)
if avg_dis < -9:
g_steps = 21
elif avg_dis < -6:
g_steps = 14
elif avg_dis < -3:
g_steps = 7
if avg_dis < 0:
g_steps = max(int(-avg_dis * 5), 1)
else:
g_steps = 1
# if avg_dis > 0:
# c_steps = min(max(int(avg_dis * 5), 1), 5)
# else:
# c_steps = 1
if avg_dis > 3:
c_steps = 3
else:
c_steps = 1
# if avg_loss_minamo > 0:
# c_steps += min(max(int(avg_loss_minamo * 3), 1), 5)
# else:
# c_steps += 0
# 每五轮输出一次图片,并保存检查点
# if avg_dis > 3:
# c_steps = 3
# else:
# c_steps = 1
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % 5 == 0:
# 输出 20 张图片,每批次 4 张,一共五批
idx = 0

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
@ -60,14 +61,10 @@ class MinamoScoreModule(nn.Module):
self.vision_model = MinamoVisionModel(tile_types)
# 输出层
self.topo_fc = nn.Sequential(
nn.Linear(512, 2048),
nn.LeakyReLU(0.2),
nn.Linear(2048, 1)
spectral_norm(nn.Linear(512, 1)),
)
self.vision_fc = nn.Sequential(
nn.Linear(512, 2048),
nn.LeakyReLU(0.2),
nn.Linear(2048, 1)
spectral_norm(nn.Linear(512, 1)),
)
def forward(self, map, graph):

View File

@ -12,7 +12,7 @@ class MinamoTopoModel(nn.Module):
super().__init__()
# 传入 softmax 概率值,直接映射
self.input_proj = nn.Sequential(
nn.Linear(tile_types, emb_dim),
spectral_norm(nn.Linear(tile_types, emb_dim)),
nn.LeakyReLU(0.2)
)
# 图卷积层
@ -25,7 +25,7 @@ class MinamoTopoModel(nn.Module):
# self.norm3 = nn.LayerNorm(out_dim)
self.fc = nn.Sequential(
nn.Linear(out_dim, feat_dim),
spectral_norm(nn.Linear(out_dim, feat_dim)),
nn.LeakyReLU(0.2)
)

View File

@ -19,7 +19,8 @@ class MinamoVisionModel(nn.Module):
nn.Flatten()
)
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim))
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)),
nn.LeakyReLU(0.2)
)
def forward(self, x):

View File

@ -2,5 +2,5 @@ VIS_DIM = 512
TOPO_DIM = 512
FEAT_DIM = 1024
VISION_WEIGHT = 0
TOPO_WEIGHT = 1
VISION_WEIGHT = 0.3
TOPO_WEIGHT = 0.7