perf: 优化训练流程,改进判别器与损失值

This commit is contained in:
unanmed 2025-04-27 14:39:03 +08:00
parent a94b07bda8
commit 5be9decf95
10 changed files with 106 additions and 137 deletions

View File

@ -1,7 +0,0 @@
#!/bin/bash
# Endlessly re-run the gan.sh training pipeline, passing an incrementing
# iteration counter. The counter starts from the first argument
# (defaults to 0 when omitted, instead of silently using an empty value).
i=${1:-0}
while true
do
    sh gan.sh "$i"
    i=$((i+1))
    echo "$i 次循环完成"
done

View File

@ -1,7 +0,0 @@
#!/bin/bash
# Run the gan.sh training pipeline once per iteration index, from $1 to $2
# inclusive. Requires bash: the C-style for-loop below is a bashism, so the
# shebang matters (running under plain sh/dash would fail).
start=$1
end=$2
# Fail loudly on missing arguments instead of silently looping zero times.
if [ -z "$start" ] || [ -z "$end" ]; then
    echo "usage: $0 <start> <end>" >&2
    exit 1
fi
for ((i=start; i<=end; i=i+1))
do
    sh gan.sh "$i"
    echo "$i 次循环完成"
done

17
gan.sh
View File

@ -1,17 +0,0 @@
#!/bin/bash
# One full training + data-processing cycle. $1 is the cycle index used to
# name the archived dataset files.
# Abort immediately if any step fails, so a broken training run cannot fall
# through into the dataset move/merge phase and corrupt the merged datasets.
set -e

# Training phase
python3 -m minamo.train --epochs 10 --resume true
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
python3 -m minamo.train --epochs 10 --resume true
python3 -m ginka.train --epochs 30 --resume true
python3 -m ginka.validate
# Training done; archive this cycle's datasets, regenerate, merge, review.
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
cd data
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:30
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
pnpm review "../minamo-dataset.json" "../datasets/minamo-dataset-merged.json"
pnpm review "../minamo-eval.json" "../datasets/minamo-eval-merged.json"
cd ..

View File

@ -325,7 +325,7 @@ def js_divergence(p, q, eps=1e-6, softmax=False):
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
return torch.clamp(0.5 * (kl_pm + kl_qm), max=10)
return torch.log1p(0.5 * (kl_pm + kl_qm))
def immutable_penalty_loss(
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
@ -334,8 +334,8 @@ def immutable_penalty_loss(
惩罚模型修改不可更改区域的损失
Args:
input: 模型输出 [B, C, H, W]概率分布 (softmax )
target: 原始输入图 [B, C, H, W]概率分布 (softmax )
input: 模型输出 [B, C, H, W]概率分布 (softmax )
target: 原始输入图 [B, C, H, W]概率分布 (softmax )
modifiable_classes: 允许被修改的类别列表
"""
not_allowed = get_not_allowed(modifiable_classes, include_illegal=True)
@ -344,13 +344,10 @@ def immutable_penalty_loss(
target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1)
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
target_mask = torch.log(target_mask + 1e-6) # 转换为 log 概率分布
input_mask = torch.log(input_mask + 1e-6) # 转换为 log 概率分布
# 差异区域(模型试图改变的地方)
penalty = F.kl_div(input_mask, target_mask, reduction='batchmean', log_target=True)
penalty = F.cross_entropy(input_mask, target_mask)
return torch.clamp(penalty, max=1)
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
@ -420,7 +417,7 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio) # 蒙版越大,交叉熵损失权重越小
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -471,7 +468,7 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
@ -496,7 +493,7 @@ class WGANGinkaLoss:
losses = [
input_head_illegal_loss(probs),
input_head_wall_loss(probs),
-js_divergence(probs_a, probs_b, softmax=False) * 0.2
-js_divergence(probs_a, probs_b, softmax=False) * 0.3
]
return sum(losses)

View File

@ -20,7 +20,7 @@ class GinkaModel(nn.Module):
def forward(self, x, stage, random=False):
if random:
x_in = F.softmax(self.head(x))
x_in = F.softmax(self.head(x), dim=1)
else:
x_in = x
x = self.input(x_in)
@ -30,7 +30,7 @@ class GinkaModel(nn.Module):
# 检查显存占用
if __name__ == "__main__":
input = torch.randn((1, 32, 13, 13)).cuda()
input = torch.randn((1, 32, 32, 32)).cuda()
# 初始化模型
model = GinkaModel().cuda()
@ -38,7 +38,7 @@ if __name__ == "__main__":
print_memory("初始化后")
# 前向传播
output = model(input, 1)
output, _ = model(input, 1, True)
print_memory("前向传播后")

View File

@ -49,12 +49,12 @@ def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.
def gen_total(gen, input, progress_detach=True, result_detach=False, random=False) -> torch.Tensor:
if progress_detach:
fake1, x_in = gen(input.detach(), 1, random)
fake2, _ = gen(F.softmax(fake1.detach()), 2)
fake3, _ = gen(F.softmax(fake2.detach()), 3)
fake2, _ = gen(F.softmax(fake1.detach(), dim=1), 2)
fake3, _ = gen(F.softmax(fake2.detach(), dim=1), 3)
else:
fake1, x_in = gen(input, 1, random)
fake2, _ = gen(F.softmax(fake1), 2)
fake3, _ = gen(F.softmax(fake2), 3)
fake2, _ = gen(F.softmax(fake1, dim=1), 2)
fake3, _ = gen(F.softmax(fake2, dim=1), 3)
if result_detach:
return fake1.detach(), fake2.detach(), fake3.detach(), x_in.detach()
else:
@ -69,7 +69,6 @@ def train():
g_steps = 1
# 训练阶段
train_stage = 1
last_stage = False
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
@ -83,7 +82,6 @@ def train():
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
@ -117,9 +115,6 @@ def train():
if data_ginka.get("stage") is not None:
train_stage = data_ginka["stage"]
if data_ginka.get("last_stage") is not None:
last_stage = data_ginka["last_stage"]
if args.load_optim:
if data_ginka.get("optim_state") is not None:
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
@ -149,22 +144,11 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
if train_stage == 4:
# 最后一个阶段训练输入头
count = 5 if stage_epoch <= 20 else 2
for _ in range(count):
optimizer_head.zero_grad()
output = F.softmax(ginka_head(masked1), dim=1)
loss_head = criterion.generator_input_head_loss(output)
loss_head.backward()
optimizer_head.step()
# ---------- 训练判别器
for _ in range(c_steps):
# 生成假样本
optimizer_minamo.zero_grad()
optimizer_ginka.zero_grad()
optimizer_head.zero_grad()
with torch.no_grad():
if train_stage == 1 or train_stage == 2:
@ -193,7 +177,6 @@ def train():
for _ in range(g_steps):
optimizer_minamo.zero_grad()
optimizer_ginka.zero_grad()
optimizer_head.zero_grad()
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
@ -210,10 +193,10 @@ def train():
loss_ce_total += loss_ce.detach()
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3, x_in = gen_total(ginka, input, True, False)
fake1, fake2, fake3, x_in = gen_total(ginka, masked1, True, False, train_stage == 4)
if train_stage == 3:
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
else:
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
@ -221,7 +204,7 @@ def train():
if train_stage == 4:
loss_head = criterion.generator_input_head_loss(x_in)
loss_head.backward()
loss_head.backward(retain_graph=True)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_g.backward()
@ -239,14 +222,20 @@ def train():
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
)
if avg_loss_ce < 1.0:
if avg_loss_ce < 0.5:
low_loss_epochs += 1
else:
low_loss_epochs = 0
# 训练流程控制
if low_loss_epochs >= 3 and train_stage == 1 and stage_epoch >= args.curr_epoch:
if train_stage >= 2:
train_stage += 1
if train_stage == 5:
train_stage = 2
if low_loss_epochs >= 5 and train_stage == 1 and stage_epoch >= args.curr_epoch:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.2
@ -254,26 +243,6 @@ def train():
low_loss_epochs = 0
stage_epoch = 0
if (train_stage == 3 or train_stage == 2) and not last_stage:
if stage_epoch >= 25:
train_stage += 1
stage_epoch = 0
if train_stage == 4:
last_stage = True
if train_stage >= 3 or last_stage:
# 第三阶段后交叉熵损失不再应该生效
mask_ratio = 1.0
if last_stage:
mask_ratio = 1.0
if train_stage == 2 and stage_epoch % 5 == 0:
train_stage = 4
if train_stage == 4 and stage_epoch % 5 == 1:
train_stage = 2
stage_epoch += 1
dataset.train_stage = train_stage
@ -305,7 +274,6 @@ def train():
"stage": train_stage,
"mask_ratio": mask_ratio,
"stage_epoch": stage_epoch,
"last_stage": last_stage
}, f"result/wgan/ginka-{epoch + 1}.pth")
torch.save({
"model_state": minamo.state_dict(),

View File

@ -2,9 +2,14 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_max_pool, GCNConv, global_mean_pool
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
from shared.graph import batch_convert_soft_map_to_graph
def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
class MinamoModel(nn.Module):
def __init__(self, tile_types=32):
@ -20,19 +25,50 @@ class MinamoModel(nn.Module):
return vision_feat, topo_feat
class MinamoScoreHead(nn.Module):
class CNNHead(nn.Module):
    """Spectral-normalized convolutional scoring head.

    Maps a [B, in_ch, H, W] feature map to a [B, out_dim] score: one 3x3
    conv (no padding), LeakyReLU, adaptive max-pool to a fixed 2x2 grid,
    then a single linear projection.
    """

    def __init__(self, in_ch, out_dim):
        super().__init__()
        self.cnn = nn.Sequential(
            spectral_norm(nn.Conv2d(in_ch, in_ch, 3)),
            nn.LeakyReLU(0.2),
            nn.AdaptiveMaxPool2d((2, 2))
        )
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(in_ch*2*2, out_dim))
        )

    def forward(self, x):
        # Pool to 2x2 so the flattened size is fixed regardless of input H/W.
        pooled = self.cnn(x)
        flat = pooled.flatten(1)
        return self.fc(flat)
class GCNHead(nn.Module):
    """Spectral-normalized graph scoring head.

    One GCN message-passing layer, LeakyReLU, global max pooling over each
    graph in the batch, then a linear projection to the score dimension.

    NOTE(review): the span contained leftover pre-refactor lines
    (vision_fc/topo_fc) interleaved with this class; they are removed here.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.gcn = GCNConv(in_dim, in_dim)
        self.fc = nn.Sequential(
            spectral_norm(nn.Linear(in_dim, out_dim))
        )

    def forward(self, x, graph):
        # x: node features; graph supplies edge_index and the batch vector
        # used for per-graph pooling.
        x = self.gcn(x, graph.edge_index)
        x = F.leaky_relu(x, 0.2)
        x = global_max_pool(x, graph.batch)
        x = self.fc(x)
        return x
class MinamoScoreHead(nn.Module):
    """Per-stage scoring head pairing a vision (CNN) and a topology (GCN) branch."""

    def __init__(self, vision_dim, topo_dim, out_dim):
        super().__init__()
        self.vision_head = CNNHead(vision_dim, out_dim)
        self.topo_head = GCNHead(topo_dim, out_dim)

    def forward(self, vis, topo, graph):
        """Return (vision_score, topo_score) for the given feature tensors."""
        return self.vision_head(vis), self.topo_head(topo, graph)
class MinamoScoreModule(nn.Module):
@ -41,20 +77,41 @@ class MinamoScoreModule(nn.Module):
self.topo_model = MinamoTopoModel(tile_types)
self.vision_model = MinamoVisionModel(tile_types)
# 输出层
self.head1 = MinamoScoreHead(512, 1)
self.head2 = MinamoScoreHead(512, 1)
self.head3 = MinamoScoreHead(512, 1)
self.head1 = MinamoScoreHead(512, 512, 1)
self.head2 = MinamoScoreHead(512, 512, 1)
self.head3 = MinamoScoreHead(512, 512, 1)
def forward(self, map, graph, stage):
vision_feat = self.vision_model(map)
topo_feat = self.topo_model(graph)
vision = self.vision_model(map)
topo = self.topo_model(graph)
if stage == 1:
vision_score, topo_score = self.head1(vision_feat, topo_feat)
vision_score, topo_score = self.head1(vision, topo, graph)
elif stage == 2:
vision_score, topo_score = self.head2(vision_feat, topo_feat)
vision_score, topo_score = self.head2(vision, topo, graph)
elif stage == 3:
vision_score, topo_score = self.head3(vision_feat, topo_feat)
vision_score, topo_score = self.head3(vision, topo, graph)
else:
raise RuntimeError("Unknown critic stage.")
score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
return score, vision_score, topo_score
# Smoke test: check GPU memory usage of the critic (requires CUDA).
if __name__ == "__main__":
    input = torch.randn((1, 32, 13, 13)).cuda()

    # Build the model on the GPU
    model = MinamoScoreModule().cuda()
    print_memory("初始化后")

    # Forward pass at stage 1; graph is derived from the soft map tensor
    output, _, _ = model(input, batch_convert_soft_map_to_graph(input), 1)
    print_memory("前向传播后")

    print(f"输入形状: feat={input.shape}")
    print(f"输出形状: output={output.shape}")
    print(f"Topo parameters: {sum(p.numel() for p in model.topo_model.parameters())}")
    print(f"Vision parameters: {sum(p.numel() for p in model.vision_model.parameters())}")
    print(f"Head parameters: {sum(p.numel() for p in model.head1.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -2,12 +2,12 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import GATConv, global_max_pool, GCNConv, global_mean_pool
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):
def __init__(
self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512, feat_dim=512
self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512
):
super().__init__()
# 传入 softmax 概率值,直接映射
@ -20,15 +20,6 @@ class MinamoTopoModel(nn.Module):
self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8)
self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1)
# self.norm1 = nn.LayerNorm(hidden_dim*8)
# self.norm2 = nn.LayerNorm(hidden_dim*8)
# self.norm3 = nn.LayerNorm(out_dim)
self.fc = nn.Sequential(
spectral_norm(nn.Linear(out_dim, feat_dim)),
nn.LeakyReLU(0.2)
)
def forward(self, graph: Data):
x = self.input_proj(graph.x)
@ -41,10 +32,5 @@ class MinamoTopoModel(nn.Module):
x = self.conv3(x, graph.edge_index)
x = F.leaky_relu(x, 0.2)
# 池化
x = global_mean_pool(x, graph.batch)
topo_vec = self.fc(x)
return topo_vec
return x

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from torch.nn.utils import spectral_norm
class MinamoVisionModel(nn.Module):
def __init__(self, in_ch=32, out_dim=512):
def __init__(self, in_ch=32, out_ch=512):
super().__init__()
self.conv = nn.Sequential(
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3)), # 11*11
@ -13,18 +13,10 @@ class MinamoVisionModel(nn.Module):
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #9*9
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 7*7
spectral_norm(nn.Conv2d(in_ch*4, out_ch, 3)), # 7*7
nn.LeakyReLU(0.2),
nn.AdaptiveAvgPool2d(2)
)
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim)),
nn.LeakyReLU(0.2)
)
def forward(self, x):
x = self.conv(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x

View File

@ -1,4 +1,4 @@
# 从头训练
python3 -u -m ginka.train_wgan >> output.log
python3 -u -m ginka.train_wgan --epochs 300 >> output.log
# 接续训练
python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
python3 -u -m ginka.train_wgan --resume true --epochs 300 --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log