diff --git a/ginka/dataset.py b/ginka/dataset.py index 265beb0..c1abe53 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -63,12 +63,12 @@ class GinkaWGANDataset(Dataset): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - min_main = random.uniform(0.75, 0.9) - max_main = random.uniform(0.9, 1) - epsilon = random.uniform(0, 0.25) - target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device) + # min_main = random.uniform(0.8, 0.9) + # max_main = random.uniform(0.9, 1) + # epsilon = random.uniform(0, 0.2) + # target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device) - return target_smooth + return target class MinamoGANDataset(Dataset): def __init__(self, refer_data_path): diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 7bfb05c..5b004e5 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -311,7 +311,7 @@ def js_divergence(P, Q, epsilon=1e-10): return js.mean() # 标量 class WGANGinkaLoss: - def __init__(self, lambda_gp=20, weight=[0.7, 0.2, 0.1], diversity_lamda=0): + def __init__(self, lambda_gp=50, weight=[0.7, 0.2, 0.1], diversity_lamda=0.2): self.lambda_gp = lambda_gp # 梯度惩罚系数 self.weight = weight self.diversity_lamda = diversity_lamda @@ -361,6 +361,8 @@ class WGANGinkaLoss: real_scores, _, _ = critic(real_data, real_graph) fake_scores, _, _ = critic(fake_data, fake_graph) + # print("Critic 输出范围", fake_scores.min().item(), fake_scores.max().item(), real_scores.min().item(), real_scores.max().item()) + # Wasserstein 距离 d_loss = fake_scores.mean() - real_scores.mean() grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data) @@ -381,10 +383,16 @@ class WGANGinkaLoss: ] return sum(losses) + + def diversity_loss(self, fake1, fake2): + fake1 = fake1[:, :, 1:-1, 1:-1] + fake2 = fake2[:, :, 1:-1, 1:-1] + + return js_divergence(fake1, fake2) def generator_loss(self, critic, fake1, fake2): """ 生成器损失函数 """ loss1 = self.generator_loss_one(critic, fake1) loss2 = self.generator_loss_one(critic, fake2) - return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * js_divergence(fake1, fake2) + return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * self.diversity_loss(fake1, fake2) diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index 8093946..019cddb 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -1,5 +1,6 @@ import argparse import os +import sys from datetime import datetime import torch import torch.optim as optim @@ -20,6 +21,8 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) os.makedirs("result/wgan", exist_ok=True) +disable_tqdm = not sys.stdout.isatty() + def parse_arguments(): parser = argparse.ArgumentParser(description="training codes") parser.add_argument("--resume", type=bool, default=False) @@ -37,11 +40,13 @@ def clip_weights(model, clip_value=0.01): def train(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.") - c_steps = 1 - g_steps = 4 - args = parse_arguments() + # c_steps = 1 if args.resume else 5 + # g_steps = 5 if args.resume else 1 + c_steps = 5 + g_steps = 1 + ginka = GinkaModel() minamo = MinamoScoreModule() ginka.to(device) @@ -50,8 +55,8 @@ def train(): dataset = GinkaWGANDataset(args.train, device) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True) - optimizer_ginka = optim.RMSprop(ginka.parameters(), lr=2e-4) - optimizer_minamo = optim.RMSprop(minamo.parameters(), lr=1e-5) + optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) + optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) criterion = WGANGinkaLoss() @@ -68,12 +73,12 @@ def train(): minamo.load_state_dict(data["model_state"], strict=False) print("Train from loaded state.") - for epoch in tqdm(range(args.epochs), desc="GAN Training"): + for epoch in tqdm(range(args.epochs), desc="GAN Training", disable=disable_tqdm): loss_total_minamo = torch.Tensor([0]).to(device) loss_total_ginka = torch.Tensor([0]).to(device) dis_total = torch.Tensor([0]).to(device) - for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress"): + for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): batch_size = real_data.size(0) real_data = real_data.to(device) real_graph = batch_convert_soft_map_to_graph(real_data) @@ -92,7 +97,11 @@ def train(): # 反向传播 dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data) loss_d.backward() - # torch.nn.utils.clip_grad_norm_(minamo_vis.parameters(), max_norm=1.0) + # torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=1.0) + # total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2) + # print("Critic 梯度范数:", total_norm.item()) + # print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item()) + # print("Critic 输出范围:", d_real.min().item(), d_real.max().item()) optimizer_minamo.step() loss_total_minamo += loss_d @@ -119,21 +128,27 @@ def train(): f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}" ) - if avg_dis < -9: - g_steps = 21 - elif avg_dis < -6: - g_steps = 14 - elif avg_dis < -3: - g_steps = 7 + if avg_dis < 0: + g_steps = max(int(-avg_dis * 5), 1) else: g_steps = 1 + + # if avg_dis > 0: + # c_steps = min(max(int(avg_dis * 5), 1), 5) + # else: + # c_steps = 1 - if avg_dis > 3: - c_steps = 3 - else: - c_steps = 1 + # if avg_loss_minamo > 0: + # c_steps += min(max(int(avg_loss_minamo * 3), 1), 5) + # else: + # c_steps += 0 - # 每五轮输出一次图片,并保存检查点 + # if avg_dis > 3: + # c_steps = 3 + # else: + # c_steps = 1 + + # 每若干轮输出一次图片,并保存检查点 if (epoch + 1) % 5 == 0: # 输出 20 张图片,每批次 4 张,一共五批 idx = 0 diff --git a/minamo/model/model.py b/minamo/model/model.py index f392953..0de0350 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.utils import spectral_norm from .vision import MinamoVisionModel from .topo import MinamoTopoModel from shared.constant import VISION_WEIGHT, TOPO_WEIGHT @@ -60,14 +61,10 @@ class MinamoScoreModule(nn.Module): self.vision_model = MinamoVisionModel(tile_types) # 输出层 self.topo_fc = nn.Sequential( - nn.Linear(512, 2048), - nn.LeakyReLU(0.2), - nn.Linear(2048, 1) + spectral_norm(nn.Linear(512, 1)), ) self.vision_fc = nn.Sequential( - nn.Linear(512, 2048), - nn.LeakyReLU(0.2), - nn.Linear(2048, 1) + spectral_norm(nn.Linear(512, 1)), ) def forward(self, map, graph): diff --git a/minamo/model/topo.py b/minamo/model/topo.py index bbf9d2d..a16c54f 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -12,7 +12,7 @@ class MinamoTopoModel(nn.Module): super().__init__() # 传入 softmax 概率值,直接映射 self.input_proj = nn.Sequential( - nn.Linear(tile_types, emb_dim), + spectral_norm(nn.Linear(tile_types, emb_dim)), nn.LeakyReLU(0.2) ) # 图卷积层 @@ -25,7 +25,7 @@ class MinamoTopoModel(nn.Module): # self.norm3 = nn.LayerNorm(out_dim) self.fc = nn.Sequential( - nn.Linear(out_dim, feat_dim), + spectral_norm(nn.Linear(out_dim, feat_dim)), nn.LeakyReLU(0.2) ) diff --git a/minamo/model/vision.py b/minamo/model/vision.py index 0ba7a25..50c2c45 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -19,7 +19,8 @@ class MinamoVisionModel(nn.Module): nn.Flatten() ) self.fc = nn.Sequential( - spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)) + spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)), + nn.LeakyReLU(0.2) ) def forward(self, x): diff --git a/shared/constant.py b/shared/constant.py index 846a53c..ed07a65 100644 --- a/shared/constant.py +++ b/shared/constant.py @@ -2,5 +2,5 @@ VIS_DIM = 512 TOPO_DIM = 512 FEAT_DIM = 1024 -VISION_WEIGHT = 0 -TOPO_WEIGHT = 1 +VISION_WEIGHT = 0.3 +TOPO_WEIGHT = 0.7