diff --git a/cycle.sh b/cycle.sh deleted file mode 100644 index accd0e7..0000000 --- a/cycle.sh +++ /dev/null @@ -1,7 +0,0 @@ -i=$1 -while true -do - sh gan.sh "$i" - i=$((i+1)) - echo "第 $i 次循环完成" -done diff --git a/cycle2.sh b/cycle2.sh deleted file mode 100644 index 84011e9..0000000 --- a/cycle2.sh +++ /dev/null @@ -1,7 +0,0 @@ -start=$1 -end=$2 -for ((i=start; i<=end; i=i+1)) -do - sh gan.sh "$i" - echo "第 $i 次循环完成" -done diff --git a/gan.sh b/gan.sh deleted file mode 100644 index 5065d5c..0000000 --- a/gan.sh +++ /dev/null @@ -1,17 +0,0 @@ -# 训练部分 -python3 -m minamo.train --epochs 10 --resume true -python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json" -python3 -m minamo.train --epochs 10 --resume true -python3 -m ginka.train --epochs 30 --resume true -python3 -m ginka.validate -# 训练完毕,处理数据 -mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" -mv "minamo-eval.json" "datasets/minamo-eval-$1.json" -cd data -pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:30 -pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10 -pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json" -pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json" -pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json" -pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json" -cd .. diff --git a/ginka/model/loss.py b/ginka/model/loss.py index c283218..efe65a0 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -325,7 +325,7 @@ def js_divergence(p, q, eps=1e-6, softmax=False): kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m) kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m) - return torch.clamp(0.5 * (kl_pm + kl_qm), max=10) + return torch.log1p(0.5 * (kl_pm + kl_qm)) def immutable_penalty_loss( pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int] @@ -334,8 +334,8 @@ def immutable_penalty_loss( 惩罚模型修改不可更改区域的损失。 Args: - input: 模型输出 [B, C, H, W],概率分布 (softmax 后) - target: 原始输入图 [B, C, H, W],概率分布 (softmax 后) + input: 模型输出 [B, C, H, W],概率分布 (softmax 前) + target: 原始输入图 [B, C, H, W],概率分布 (softmax 前) modifiable_classes: 允许被修改的类别列表 """ not_allowed = get_not_allowed(modifiable_classes, include_illegal=True) @@ -344,13 +344,10 @@ def immutable_penalty_loss( target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1) target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float() - target_mask = torch.log(target_mask + 1e-6) # 转换为 log 概率分布 - input_mask = torch.log(input_mask + 1e-6) # 转换为 log 概率分布 - # 差异区域(模型试图改变的地方) - penalty = F.kl_div(input_mask, target_mask, reduction='batchmean', log_target=True) + penalty = F.cross_entropy(input_mask, target_mask) - return torch.clamp(penalty, max=1) + return penalty class WGANGinkaLoss: def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]): @@ -420,7 +417,7 @@ class WGANGinkaLoss: fake_scores, _, _ = critic(probs_fake, fake_graph, stage) minamo_loss = -torch.mean(fake_scores) ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小 - immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage]) + immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake) fake_a, fake_b = fake.chunk(2, dim=0) @@ -471,7 +468,7 @@ class WGANGinkaLoss: fake_scores, _, _ = critic(probs_fake, fake_graph, stage) minamo_loss = -torch.mean(fake_scores) - immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage]) + immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage]) constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake) fake_a, fake_b = fake.chunk(2, dim=0) @@ -496,7 +493,7 @@ class WGANGinkaLoss: losses = [ input_head_illegal_loss(probs), input_head_wall_loss(probs), - -js_divergence(probs_a, probs_b, softmax=False) * 0.2 + -js_divergence(probs_a, probs_b, softmax=False) * 0.3 ] return sum(losses) diff --git a/ginka/model/model.py b/ginka/model/model.py index 02ee7fc..7fddb70 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -20,7 +20,7 @@ class GinkaModel(nn.Module): def forward(self, x, stage, random=False): if random: - x_in = F.softmax(self.head(x)) + x_in = F.softmax(self.head(x), dim=1) else: x_in = x x = self.input(x_in) @@ -30,7 +30,7 @@ class GinkaModel(nn.Module): # 检查显存占用 if __name__ == "__main__": - input = torch.randn((1, 32, 13, 13)).cuda() + input = torch.randn((1, 32, 32, 32)).cuda() # 初始化模型 model = GinkaModel().cuda() @@ -38,7 +38,7 @@ if __name__ == "__main__": print_memory("初始化后") # 前向传播 - output = model(input, 1) + output, _ = model(input, 1, True) print_memory("前向传播后") diff --git a/ginka/train_wgan.py b/ginka/train_wgan.py index a6c2f6d..06633ce 100644 --- a/ginka/train_wgan.py +++ b/ginka/train_wgan.py @@ -49,12 +49,12 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch. def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor: if progress_detach: fake1, x_in = gen(input.detach(), 1, random) - fake2, _ = gen(F.softmax(fake1.detach()), 2) - fake3, _ = gen(F.softmax(fake2.detach()), 3) + fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2) + fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3) else: fake1, x_in = gen(input, 1, random) - fake2, _ = gen(F.softmax(fake1), 2) - fake3, _ = gen(F.softmax(fake2), 3) + fake2, _ = gen(F.softmax(fake1, dim=1), 2) + fake3, _ = gen(F.softmax(fake2, dim=1), 3) if result_detach: return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach() else: @@ -69,7 +69,6 @@ def train(): g_steps = 1 # 训练阶段 train_stage = 1 - last_stage = False mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练 stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程 @@ -83,7 +82,6 @@ def train(): dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True) optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9)) - optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9)) optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9)) # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs) @@ -117,9 +115,6 @@ def train(): if data_ginka.get("stage") is not None: train_stage = data_ginka["stage"] - if data_ginka.get("last_stage") is not None: - last_stage = data_ginka["last_stage"] - if args.load_optim: if data_ginka.get("optim_state") is not None: optimizer_ginka.load_state_dict(data_ginka["optim_state"]) @@ -149,22 +144,11 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch] - if train_stage == 4: - # 最后一个阶段训练输入头 - count = 5 if stage_epoch <= 20 else 2 - for _ in range(count): - optimizer_head.zero_grad() - output = F.softmax(ginka_head(masked1), dim=1) - loss_head = criterion.generator_input_head_loss(output) - loss_head.backward() - optimizer_head.step() - # ---------- 训练判别器 for _ in range(c_steps): # 生成假样本 optimizer_minamo.zero_grad() optimizer_ginka.zero_grad() - optimizer_head.zero_grad() with torch.no_grad(): if train_stage == 1 or train_stage == 2: @@ -193,7 +177,6 @@ def train(): for _ in range(g_steps): optimizer_minamo.zero_grad() optimizer_ginka.zero_grad() - optimizer_head.zero_grad() if train_stage == 1 or train_stage == 2: fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False) @@ -210,10 +193,10 @@ def train(): loss_ce_total += loss_ce.detach() elif train_stage == 3 or train_stage == 4: - fake1, fake2, fake3, x_in = gen_total(ginka, input, True, False) + fake1, fake2, fake3, x_in = gen_total(ginka, masked1, True, False, train_stage == 4) if train_stage == 3: - loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input) + loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1) else: loss_g1 = criterion.generator_loss_total(minamo, 1, fake1) loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1) @@ -221,7 +204,7 @@ def train(): if train_stage == 4: loss_head = criterion.generator_input_head_loss(x_in) - loss_head.backward() + loss_head.backward(retain_graph=True) loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0 loss_g.backward() @@ -239,40 +222,26 @@ def train(): f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}" ) - if avg_loss_ce < 1.0: + if avg_loss_ce < 0.5: low_loss_epochs += 1 else: low_loss_epochs = 0 # 训练流程控制 - if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch: + if train_stage >= 2: + train_stage += 1 + + if train_stage == 5: + train_stage = 2 + + if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= args.curr_epoch: if mask_ratio >= 0.9: train_stage = 2 mask_ratio += 0.2 mask_ratio = min(mask_ratio, 0.9) low_loss_epochs = 0 stage_epoch = 0 - - if (train_stage == 3 or train_stage == 2) and not last_stage: - if stage_epoch >= 25: - train_stage += 1 - stage_epoch = 0 - - if train_stage == 4: - last_stage = True - - if train_stage >= 3 or last_stage: - # 第三阶段后交叉熵损失不再应该生效 - mask_ratio = 1.0 - - if last_stage: - mask_ratio = 1.0 - if train_stage == 2 and stage_epoch % 5 == 0: - train_stage = 4 - - if train_stage == 4 and stage_epoch % 5 == 1: - train_stage = 2 stage_epoch += 1 @@ -305,7 +274,6 @@ def train(): "stage": train_stage, "mask_ratio": mask_ratio, "stage_epoch": stage_epoch, - "last_stage": last_stage }, f"result/wgan/ginka-{epoch + 1}.pth") torch.save({ "model_state": minamo.state_dict(), diff --git a/minamo/model/model.py b/minamo/model/model.py index 209f59c..702cd64 100644 --- a/minamo/model/model.py +++ b/minamo/model/model.py @@ -2,9 +2,14 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import spectral_norm +from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool from .vision import MinamoVisionModel from .topo import MinamoTopoModel from shared.constant import VISION_WEIGHT, TOPO_WEIGHT +from shared.graph import batch_convert_soft_map_to_graph + +def print_memory(tag=""): + print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB") class MinamoModel(nn.Module): def __init__(self, tile_types=32): @@ -19,20 +24,51 @@ class MinamoModel(nn.Module): topo_feat = self.topo_model(graph) return vision_feat, topo_feat + +class CNNHead(nn.Module): + def __init__(self, in_ch, out_dim): + super().__init__() + self.cnn = nn.Sequential( + spectral_norm(nn.Conv2d(in_ch, in_ch, 3)), + nn.LeakyReLU(0.2), + + nn.AdaptiveMaxPool2d((2, 2)) + ) + self.fc = nn.Sequential( + spectral_norm(nn.Linear(in_ch*2*2, out_dim)) + ) + + def forward(self, x): + x = self.cnn(x) + B, C, H, W = x.shape + x = x.view(B, -1) + x = self.fc(x) + return x -class MinamoScoreHead(nn.Module): +class GCNHead(nn.Module): def __init__(self, in_dim, out_dim): super().__init__() - self.vision_fc = nn.Sequential( - spectral_norm(nn.Linear(in_dim, out_dim)), - ) - self.topo_fc = nn.Sequential( + self.gcn = GCNConv(in_dim, in_dim) + self.fc = nn.Sequential( spectral_norm(nn.Linear(in_dim, out_dim)) ) - def forward(self, vis_feat, topo_feat): - vis_score = self.vision_fc(vis_feat) - topo_score = self.topo_fc(topo_feat) + def forward(self, x, graph): + x = self.gcn(x, graph.edge_index) + x = F.leaky_relu(x, 0.2) + x = global_max_pool(x, graph.batch) + x = self.fc(x) + return x + +class MinamoScoreHead(nn.Module): + def __init__(self, vision_dim, topo_dim, out_dim): + super().__init__() + self.vision_head = CNNHead(vision_dim, out_dim) + self.topo_head = GCNHead(topo_dim, out_dim) + + def forward(self, vis, topo, graph): + vis_score = self.vision_head(vis) + topo_score = self.topo_head(topo, graph) return vis_score, topo_score class MinamoScoreModule(nn.Module): @@ -41,20 +77,41 @@ class MinamoScoreModule(nn.Module): self.topo_model = MinamoTopoModel(tile_types) self.vision_model = MinamoVisionModel(tile_types) # 输出层 - self.head1 = MinamoScoreHead(512, 1) - self.head2 = MinamoScoreHead(512, 1) - self.head3 = MinamoScoreHead(512, 1) + self.head1 = MinamoScoreHead(512, 512, 1) + self.head2 = MinamoScoreHead(512, 512, 1) + self.head3 = MinamoScoreHead(512, 512, 1) def forward(self, map, graph, stage): - vision_feat = self.vision_model(map) - topo_feat = self.topo_model(graph) + vision = self.vision_model(map) + topo = self.topo_model(graph) if stage == 1: - vision_score, topo_score = self.head1(vision_feat, topo_feat) + vision_score, topo_score = self.head1(vision, topo, graph) elif stage == 2: - vision_score, topo_score = self.head2(vision_feat, topo_feat) + vision_score, topo_score = self.head2(vision, topo, graph) elif stage == 3: - vision_score, topo_score = self.head3(vision_feat, topo_feat) + vision_score, topo_score = self.head3(vision, topo, graph) else: raise RuntimeError("Unknown critic stage.") score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score return score, vision_score, topo_score + +# 检查显存占用 +if __name__ == "__main__": + input = torch.randn((1, 32, 13, 13)).cuda() + + # 初始化模型 + model = MinamoScoreModule().cuda() + + print_memory("初始化后") + + # 前向传播 + output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1) + + print_memory("前向传播后") + + print(f"输入形状: feat={input.shape}") + print(f"输出形状: output={output.shape}") + print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}") + print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}") + print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}") + print(f"Total parameters: {sum(p.numel() for p in model.parameters())}") diff --git a/minamo/model/topo.py b/minamo/model/topo.py index a16c54f..967f08f 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -2,12 +2,12 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.nn.utils import spectral_norm -from torch_geometric.nn import GATConv, global_max_pool, GCNConv, global_mean_pool +from torch_geometric.nn import GATConv from torch_geometric.data import Data class MinamoTopoModel(nn.Module): def __init__( - self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512, feat_dim=512 + self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512 ): super().__init__() # 传入 softmax 概率值,直接映射 @@ -20,15 +20,6 @@ class MinamoTopoModel(nn.Module): self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8) self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1) - # self.norm1 = nn.LayerNorm(hidden_dim*8) - # self.norm2 = nn.LayerNorm(hidden_dim*8) - # self.norm3 = nn.LayerNorm(out_dim) - - self.fc = nn.Sequential( - spectral_norm(nn.Linear(out_dim, feat_dim)), - nn.LeakyReLU(0.2) - ) - def forward(self, graph: Data): x = self.input_proj(graph.x) @@ -41,10 +32,5 @@ class MinamoTopoModel(nn.Module): x = self.conv3(x, graph.edge_index) x = F.leaky_relu(x, 0.2) - # 池化 - x = global_mean_pool(x, graph.batch) - - topo_vec = self.fc(x) - - return topo_vec + return x \ No newline at end of file diff --git a/minamo/model/vision.py b/minamo/model/vision.py index e52a03e..465760c 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from torch.nn.utils import spectral_norm class MinamoVisionModel(nn.Module): - def __init__(self, in_ch=32, out_dim=512): + def __init__(self, in_ch=32, out_ch=512): super().__init__() self.conv = nn.Sequential( spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11 @@ -13,18 +13,10 @@ class MinamoVisionModel(nn.Module): spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9 nn.LeakyReLU(0.2), - spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7 + spectral_norm(nn.Conv2d(in_ch*4, out_ch, 3)), # 7*7 nn.LeakyReLU(0.2), - - nn.AdaptiveAvgPool2d(2) - ) - self.fc = nn.Sequential( - spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)), - nn.LeakyReLU(0.2) ) def forward(self, x): x = self.conv(x) - x = x.view(x.size(0), -1) - x = self.fc(x) return x diff --git a/train.sh b/train.sh index 152c69e..289e32e 100644 --- a/train.sh +++ b/train.sh @@ -1,4 +1,4 @@ # 从头训练 -python3 -u -m ginka.train_wgan >> output.log +python3 -u -m ginka.train_wgan --epochs 300 >> output.log # 接续训练 -python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log +python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log