feat: 改进输出头和特征融合

This commit is contained in:
unanmed 2025-04-19 15:05:26 +08:00
parent ef0b7ffba2
commit 87016c67e8
9 changed files with 218 additions and 146 deletions

View File

@ -83,12 +83,12 @@ class GinkaWGANDataset(Dataset):
self.mask_ratio1 = 0.1
self.mask_ratio2 = 0.1
self.mask_ratio3 = 0.1
self.random_ratio = 0.0
def __len__(self):
return len(self.data)
def handle_stage1(self, target):
# 课程学习第一阶段,蒙版填充
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
@ -96,33 +96,30 @@ class GinkaWGANDataset(Dataset):
return removed1, masked1, removed2, masked2, removed3, masked3
def handle_stage2(self, target):
# 课程学习第二阶段,完全随机蒙版
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
# 后面两个阶段由于会保留一些类别,所以完全随机遮挡即可
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))
if self.random_ratio > 0:
rd = random.uniform(0, self.random_ratio)
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
masked2 = random_smooth_onehot(masked2, min_main=1 - rd, max_main=1.0, epsilon=rd)
masked3 = random_smooth_onehot(masked3, min_main=1 - rd, max_main=1.0, epsilon=rd)
return removed1, masked1, removed2, masked2, removed3, masked3
def handle_stage3(self, target):
# 第三阶段,联合生成,输入随机蒙版
rd = random.uniform(0, self.random_ratio)
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
def handle_stage4(self, target):
input1 = torch.rand((32, 13, 13))
# 第四阶段,与第二阶段交替进行,完全随机输入
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
return removed1, input1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
rand = torch.rand(32, 32, 32, device=target.device)
return removed1, rand, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
def __getitem__(self, idx):
item = self.data[idx]
@ -137,19 +134,9 @@ class GinkaWGANDataset(Dataset):
elif self.train_stage == 3:
return self.handle_stage3(target)
elif self.train_stage == 4:
self.mask_ratio1 = self.mask_ratio2 = self.mask_ratio3 = random.uniform(0, 0.9)
self.random_ratio = 0.2
mode = random.choices([1, 2, 3, 4], weights=[0.2, 0.2, 0.2, 0.4])
if mode == 1:
return self.handle_stage1(target)
elif mode == 2:
return self.handle_stage2(target)
elif mode == 3:
return self.handle_stage3(target)
else:
return self.handle_stage4(target)
return self.handle_stage4(target)
raise RuntimeError(f"Invalid train stage: {self.train_stage}")

64
ginka/model/common.py Normal file
View File

@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import grid
class DoubleConvBlock(nn.Module):
    """Two stacked (3x3 conv -> InstanceNorm -> ELU) stages.

    Channel widths are read from ``feats`` as (in, hidden, out); 'same'
    padding with replicate mode preserves the spatial size.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        in_ch, mid_ch, out_ch = feats
        layers = []
        for c_in, c_out in ((in_ch, mid_ch), (mid_ch, out_ch)):
            layers += [
                nn.Conv2d(c_in, c_out, 3, padding=1, padding_mode='replicate'),
                nn.InstanceNorm2d(c_out),
                nn.ELU(),
            ]
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages; returns [B, feats[2], H, W]."""
        return self.cnn(x)
class GCNBlock(nn.Module):
    """Two-layer GCN applied over a 4-connected grid graph of the feature map.

    Every spatial position of a [B, C, H, W] input becomes one graph node;
    the edge index comes from ``torch_geometric.utils.grid`` and is built
    once for a single (h, w) map, then offset per batch element in
    ``forward``.  NOTE(review): assumes the runtime H, W match the w, h
    passed to ``__init__`` — confirm against callers.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        self.single_edge_index, _ = grid(h, w)  # [2, E] for a single map

    def forward(self, x):
        # x: [B, C, H, W]
        B, C, H, W = x.shape
        # Reshape to [B * H * W, C] so each pixel is a node row.
        x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
        # Construct batched edge index (disjoint subgraph per sample).
        device = x.device
        edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
        # Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
        # batch = torch.arange(B, device=device).repeat_interleave(H * W)
        # GCN forward: conv -> LayerNorm -> ELU, twice.
        x = self.conv1(x, edge_index)
        x = F.elu(self.norm1(x))
        x = self.conv2(x, edge_index)
        x = F.elu(self.norm2(x))
        # Reshape back to [B, C, H, W]
        x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
        return x

    def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
        # Offset the shared per-map edge_index by the node count of each
        # preceding sample so every batch element forms a disjoint subgraph.
        edge_index = edge_index.clone()  # [2, E]
        batch_edge_index = []
        for i in range(B):
            offset = i * num_nodes_per_batch
            batch_edge_index.append(edge_index + offset)
        return torch.cat(batch_edge_index, dim=1)

View File

@ -1,6 +1,31 @@
import torch
import torch.nn as nn
class RandomInputHead(nn.Module):
    """Projects a 1-channel random map to a 32-channel feature map.

    Three 3x3 conv/InstanceNorm/ELU stages widen the channels
    (1 -> 32 -> 64 -> 128), then a 1x1 conv projects back down to 32.
    Spatial size is preserved throughout.
    NOTE(review): ``in_size`` and ``out_size`` are accepted but unused —
    confirm whether they were meant to drive a resize step.
    """

    def __init__(self, in_size=(32, 32), out_size=(32, 32)):
        super().__init__()
        stages = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128)):
            stages += [
                nn.Conv2d(c_in, c_out, 3, padding=1, padding_mode='replicate'),
                nn.InstanceNorm2d(c_out),
                nn.ELU(),
            ]
        self.conv = nn.Sequential(*stages)
        # 1x1 projection down to the 32 channels consumed downstream.
        self.out_conv = nn.Sequential(
            nn.Conv2d(128, 32, 1),
        )

    def forward(self, x):
        """x: [B, 1, H, W] -> [B, 32, H, W]."""
        return self.out_conv(self.conv(x))
class GinkaInput(nn.Module):
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
super().__init__()

View File

@ -64,7 +64,7 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]):
return loss_unallowed
def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]):
def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13]):
"""限定内部允许出现的图块种类
Args:
@ -235,6 +235,21 @@ def adaptive_count_loss(
return total_loss
def input_head_illegal_loss(input_map, allowed_classes=(0, 1)):
    """Penalize probability mass assigned to disallowed classes.

    Args:
        input_map: class probabilities, shape [B, C, H, W].
        allowed_classes: channel indices that carry no penalty.

    Returns:
        Scalar tensor: total mass on illegal channels divided by
        ``input_map.numel()`` (note: the divisor includes the channel dim).
    """
    num_classes = input_map.shape[1]
    allowed = set(allowed_classes)
    illegal = [c for c in range(num_classes) if c not in allowed]
    return input_map[:, illegal].sum() / input_map.numel()
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
    """Penalize maps whose mean wall probability exceeds ``max_wall_ratio``.

    Args:
        input_map: class probabilities, shape [B, C, H, W].
        max_wall_ratio: penalty-free upper bound on the mean wall probability.
        wall_class: channel index of the wall class.

    Returns:
        Scalar tensor: the excess over the bound, clamped at zero.
    """
    mean_wall = input_map[:, wall_class].mean()  # average over B, H, W
    excess = mean_wall - max_wall_ratio
    return excess.clamp(min=0.0)
class GinkaLoss(nn.Module):
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
"""Ginka Model 损失函数部分
@ -310,7 +325,7 @@ def js_divergence(p, q, eps=1e-6, softmax=False):
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
return 0.5 * (kl_pm + kl_qm)
return torch.clamp(0.5 * (kl_pm + kl_qm), max=10)
def immutable_penalty_loss(
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
@ -322,7 +337,6 @@ def immutable_penalty_loss(
input: 模型输出 [B, C, H, W]概率分布 (softmax )
target: 原始输入图 [B, C, H, W]概率分布 (softmax )
modifiable_classes: 允许被修改的类别列表
penalty_weight: 对非允许修改区域的惩罚系数
"""
not_allowed = get_not_allowed(modifiable_classes, include_illegal=True)
input_mask = pred[:, not_allowed, :, :]
@ -330,14 +344,17 @@ def immutable_penalty_loss(
target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1)
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
# 差异区域(模型试图改变的地方)
penalty = F.l1_loss(input_mask, target_mask)
target_mask = torch.log(target_mask + 1e-6) # 转换为 log 概率分布
input_mask = torch.log(input_mask + 1e-6) # 转换为 log 概率分布
return penalty
# 差异区域(模型试图改变的地方)
penalty = F.kl_div(input_mask, target_mask, reduction='batchmean', log_target=True)
return torch.clamp(penalty, max=1)
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.01]):
# weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
@ -402,18 +419,18 @@ class WGANGinkaLoss:
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
ce_loss = F.cross_entropy(fake, real)
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
# fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
ce_loss * self.weight[1], # 蒙版越大,交叉熵损失权重越小
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
@ -433,12 +450,12 @@ class WGANGinkaLoss:
minamo_loss = -torch.mean(fake_scores)
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
# fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
@ -457,13 +474,13 @@ class WGANGinkaLoss:
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
fake_a, fake_b = fake.chunk(2, dim=0)
# fake_a, fake_b = fake.chunk(2, dim=0)
losses = [
minamo_loss * self.weight[0],
immutable_loss * self.weight[2],
constraint_loss * self.weight[3],
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
]
if stage == 1:
@ -472,3 +489,14 @@ class WGANGinkaLoss:
losses.append(entrance_loss * self.weight[4])
return sum(losses)
def generator_input_head_loss(self, probs: torch.Tensor) -> torch.Tensor:
    """Combined training loss for the random-input head.

    Sums the illegal-class penalty, the wall-ratio penalty, and a negative
    JS divergence between the two batch halves (rewarding diversity
    between them, weighted 0.2). ``probs`` is expected to already be a
    probability distribution (softmax applied by the caller).
    """
    probs_a, probs_b = probs.chunk(2, dim=0)
    illegal = input_head_illegal_loss(probs)
    wall = input_head_wall_loss(probs)
    diversity = -js_divergence(probs_a, probs_b, softmax=False) * 0.2
    return illegal + wall + diversity

View File

@ -1,25 +1,23 @@
import torch
import torch.nn as nn
from .common import GCNBlock, DoubleConvBlock
class StageHead(nn.Module):
def __init__(self, in_ch, out_ch, out_size=(13, 13)):
super().__init__()
self.head = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch),
nn.ELU(),
nn.Conv2d(in_ch, in_ch, 1),
nn.InstanceNorm2d(in_ch),
nn.ELU(),
)
self.cnn_head = DoubleConvBlock([in_ch, in_ch*2, in_ch])
self.gcn_head = GCNBlock(in_ch, in_ch*2, in_ch, 32, 32)
self.fusion = DoubleConvBlock([in_ch*2, in_ch*4, in_ch])
self.pool = nn.Sequential(
nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(in_ch, out_ch, 1)
)
def forward(self, x):
x = self.head(x)
x_cnn = self.cnn_head(x)
x_gcn = self.gcn_head(x)
x = torch.cat([x_cnn, x_gcn], dim=1)
x = self.fusion(x)
x = self.pool(x)
return x

View File

@ -1,9 +1,8 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.utils import grid
from shared.attention import ChannelAttention
from .common import GCNBlock, DoubleConvBlock
class GinkaTransformerEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
@ -35,7 +34,7 @@ class GinkaTransformerEncoder(nn.Module):
return x
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, atte=True):
def __init__(self, in_ch, out_ch, attn=True):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
@ -44,63 +43,17 @@ class ConvBlock(nn.Module):
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
)
if atte:
if attn:
self.conv.append(ChannelAttention(out_ch))
self.conv.append(nn.ELU())
def forward(self, x):
return self.conv(x)
class GCNBlock(nn.Module):
def __init__(self, in_ch, hidden_ch, out_ch, w, h):
super().__init__()
self.conv1 = GCNConv(in_ch, hidden_ch)
self.conv2 = GCNConv(hidden_ch, out_ch)
self.norm1 = nn.LayerNorm(hidden_ch)
self.norm2 = nn.LayerNorm(out_ch)
self.single_edge_index, _ = grid(h, w) # [2, E] for a single map
def forward(self, x):
# x: [B, C, H, W]
B, C, H, W = x.shape
# Reshape to [B * H * W, C]
x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
# Construct batched edge index
device = x.device
edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
# Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
# batch = torch.arange(B, device=device).repeat_interleave(H * W)
# GCN forward
x = self.conv1(x, edge_index)
x = F.elu(self.norm1(x))
x = self.conv2(x, edge_index)
x = F.elu(self.norm2(x))
# Reshape back to [B, C, H, W]
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
return x
def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
# 批次偏移 edge_index
edge_index = edge_index.clone() # [2, E]
batch_edge_index = []
for i in range(B):
offset = i * num_nodes_per_batch
batch_edge_index.append(edge_index + offset)
return torch.cat(batch_edge_index, dim=1)
class FusionModule(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 1),
nn.InstanceNorm2d(out_ch),
nn.ELU()
)
self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
def forward(self, x1, x2):
x = torch.cat([x1, x2], dim=1)

View File

@ -4,17 +4,16 @@ import sys
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from .model.input import RandomInputHead
from minamo.model.model import MinamoScoreModule
from minamo.model.similarity import MinamoSimilarityModel
from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
BATCH_SIZE = 16
@ -67,17 +66,15 @@ def train():
c_steps = 5
g_steps = 1
# 1 代表课程学习阶段2 代表课程学习后,逐渐转为联合学习的阶段
# 3 代表课程学习后的联合遮挡学习阶段4 代表最后随机输入的联合学习阶段
# 训练阶段
train_stage = 1
last_stage = False
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
random_ratio = 0
stage3_epoch = 0 # 第三阶段 epoch 数,若干轮后进入第四阶段
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
ginka = GinkaModel()
minamo = MinamoScoreModule()
ginka.to(device)
minamo.to(device)
ginka = GinkaModel().to(device)
ginka_head = RandomInputHead().to(device)
minamo = MinamoScoreModule().to(device)
dataset = GinkaWGANDataset(args.train, device)
dataset_val = GinkaWGANDataset(args.validate, device)
@ -85,6 +82,7 @@ def train():
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
@ -112,15 +110,15 @@ def train():
if data_ginka.get("mask_ratio") is not None:
mask_ratio = data_ginka["mask_ratio"]
if data_ginka.get("random_ratio") is not None:
random_ratio = data_ginka["random_ratio"]
if data_ginka.get("stage_epoch3") is not None:
stage3_epoch = data_ginka["stage_epoch3"]
if data_ginka.get("stage_epoch") is not None:
stage_epoch = data_ginka["stage_epoch"]
if data_ginka.get("stage") is not None:
train_stage = data_ginka["stage"]
if data_ginka.get("last_stage") is not None:
last_stage = data_ginka["last_stage"]
if args.load_optim:
if data_ginka.get("optim_state") is not None:
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
@ -131,13 +129,11 @@ def train():
dataset.mask_ratio1 = mask_ratio
dataset.mask_ratio2 = mask_ratio
dataset.mask_ratio3 = mask_ratio
dataset.random_ratio = random_ratio
dataset_val.train_stage = train_stage
dataset_val.mask_ratio1 = mask_ratio
dataset_val.mask_ratio2 = mask_ratio
dataset_val.mask_ratio3 = mask_ratio
dataset_val.random_ratio = random_ratio
print("Train from loaded state.")
@ -152,16 +148,34 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
if train_stage == 4:
# 最后一个阶段训练输入头
count = 5 if stage_epoch <= 20 else 2
for _ in range(count):
optimizer_head.zero_grad()
output = F.softmax(ginka_head(masked1), dim=1)
loss_head = criterion.generator_input_head_loss(output)
loss_head.backward()
optimizer_head.step()
# ---------- 训练判别器
for _ in range(c_steps):
# 生成假样本
optimizer_minamo.zero_grad()
optimizer_ginka.zero_grad()
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
optimizer_head.zero_grad()
with torch.no_grad():
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
elif train_stage == 3:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
elif train_stage == 4:
input = F.softmax(ginka_head(masked1), dim=1)
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
@ -183,6 +197,7 @@ def train():
for _ in range(g_steps):
optimizer_minamo.zero_grad()
optimizer_ginka.zero_grad()
optimizer_head.zero_grad()
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
@ -199,10 +214,12 @@ def train():
loss_ce_total += loss_ce.detach()
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
fake1, fake2, fake3 = gen_total(ginka, input, True, False)
if train_stage == 3:
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
else:
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
@ -221,43 +238,42 @@ def train():
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
)
if avg_loss_ce < 0.5:
if avg_loss_ce < 1.0:
low_loss_epochs += 1
else:
low_loss_epochs = 0
if low_loss_epochs >= 3 and train_stage == 2:
if random_ratio >= 0.5:
train_stage = 3
random_ratio += 0.2
random_ratio = min(random_ratio, 0.5)
low_loss_epochs = 0
if low_loss_epochs >= 3 and train_stage == 1:
if mask_ratio >= 0.9:
train_stage = 2
stage_epoch = 0
mask_ratio += 0.2
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
if train_stage == 3:
stage3_epoch += 1
# 十轮足够了
if stage3_epoch >= 10:
train_stage = 4
stage3_epoch = 0
if train_stage == 3 or train_stage == 2:
if stage_epoch >= 25:
train_stage += 1
stage_epoch = 0
if train_stage >= 2:
# 第二阶段后 L1 损失不再应该生效
if train_stage >= 3:
# 第三阶段后交叉熵损失不再应该生效
mask_ratio = 1.0
if last_stage:
if train_stage == 2 and stage_epoch % 5 == 0:
train_stage = 4
if train_stage == 4 and stage_epoch % 5 == 1:
train_stage = 2
stage_epoch += 1
dataset.train_stage = train_stage
dataset_val.train_stage = train_stage
dataset.random_ratio = random_ratio
dataset_val.random_ratio = random_ratio
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
@ -284,8 +300,8 @@ def train():
"g_steps": g_steps,
"stage": train_stage,
"mask_ratio": mask_ratio,
"random_ratio": random_ratio,
"stage3_epoch": stage3_epoch,
"stage_epoch": stage_epoch,
"last_stage": last_stage
}, f"result/wgan/ginka-{epoch + 1}.pth")
torch.save({
"model_state": minamo.state_dict(),
@ -300,7 +316,8 @@ def train():
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
elif train_stage == 3 or train_stage == 4:
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()

BIN
tiles/13.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 419 B

View File

@ -1,4 +1,4 @@
# 从头训练
python3 -u -m ginka.train_wgan >> output.log
# 接续训练
python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log