mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: 改进输出头和特征融合
This commit is contained in:
parent
ef0b7ffba2
commit
87016c67e8
@ -83,12 +83,12 @@ class GinkaWGANDataset(Dataset):
|
||||
self.mask_ratio1 = 0.1
|
||||
self.mask_ratio2 = 0.1
|
||||
self.mask_ratio3 = 0.1
|
||||
self.random_ratio = 0.0
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data)
|
||||
|
||||
def handle_stage1(self, target):
|
||||
# 课程学习第一阶段,蒙版填充
|
||||
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
|
||||
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
|
||||
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
|
||||
@ -96,33 +96,30 @@ class GinkaWGANDataset(Dataset):
|
||||
return removed1, masked1, removed2, masked2, removed3, masked3
|
||||
|
||||
def handle_stage2(self, target):
    """Curriculum stage 2: fully randomized mask ratios.

    Returns the same 6-tuple as stage 1:
    (removed1, masked1, removed2, masked2, removed3, masked3).
    """
    removed1, masked1 = apply_curriculum_mask(
        target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
    # The later two sub-stages already keep certain tile classes, so a plain
    # uniformly random mask ratio is sufficient for them.
    removed2, masked2 = apply_curriculum_mask(
        target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 1))
    removed3, masked3 = apply_curriculum_mask(
        target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 1))

    if self.random_ratio > 0:
        # One shared smoothing strength for all three masked maps.
        rd = random.uniform(0, self.random_ratio)
        masked1 = random_smooth_onehot(masked1, min_main=1 - rd, max_main=1.0, epsilon=rd)
        masked2 = random_smooth_onehot(masked2, min_main=1 - rd, max_main=1.0, epsilon=rd)
        masked3 = random_smooth_onehot(masked3, min_main=1 - rd, max_main=1.0, epsilon=rd)

    return removed1, masked1, removed2, masked2, removed3, masked3
|
||||
|
||||
def handle_stage3(self, target):
    """Curriculum stage 3: joint generation driven by a randomly masked input.

    Only the stage-1 map gets a smoothed mask; stages 2 and 3 return
    removal-only maps paired with zero tensors as their masked inputs.
    """
    smooth_eps = random.uniform(0, self.random_ratio)
    removed1, masked1 = apply_curriculum_mask(
        target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
    removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
    removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
    masked1 = random_smooth_onehot(
        masked1, min_main=1 - smooth_eps, max_main=1.0, epsilon=smooth_eps)
    # Distinct zero tensors per slot so callers never share storage.
    return removed1, masked1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
|
||||
|
||||
|
||||
def handle_stage4(self, target):
|
||||
input1 = torch.rand((32, 13, 13))
|
||||
# 第四阶段,与第二阶段交替进行,完全随机输入
|
||||
removed1 = apply_curriculum_remove(target, STAGE1_REMOVE)
|
||||
removed2 = apply_curriculum_remove(target, STAGE2_REMOVE)
|
||||
removed3 = apply_curriculum_remove(target, STAGE3_REMOVE)
|
||||
return removed1, input1, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
|
||||
rand = torch.rand(32, 32, 32, device=target.device)
|
||||
return removed1, rand, removed2, torch.zeros_like(target), removed3, torch.zeros_like(target)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data[idx]
|
||||
@ -137,19 +134,9 @@ class GinkaWGANDataset(Dataset):
|
||||
|
||||
elif self.train_stage == 3:
|
||||
return self.handle_stage3(target)
|
||||
|
||||
|
||||
elif self.train_stage == 4:
|
||||
self.mask_ratio1 = self.mask_ratio2 = self.mask_ratio3 = random.uniform(0, 0.9)
|
||||
self.random_ratio = 0.2
|
||||
mode = random.choices([1, 2, 3, 4], weights=[0.2, 0.2, 0.2, 0.4])
|
||||
if mode == 1:
|
||||
return self.handle_stage1(target)
|
||||
elif mode == 2:
|
||||
return self.handle_stage2(target)
|
||||
elif mode == 3:
|
||||
return self.handle_stage3(target)
|
||||
else:
|
||||
return self.handle_stage4(target)
|
||||
return self.handle_stage4(target)
|
||||
|
||||
raise RuntimeError(f"Invalid train stage: {self.train_stage}")
|
||||
|
||||
64
ginka/model/common.py
Normal file
64
ginka/model/common.py
Normal file
@ -0,0 +1,64 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.utils import grid
|
||||
|
||||
class DoubleConvBlock(nn.Module):
    """Two Conv2d(3x3) -> InstanceNorm2d -> ELU stages applied back to back.

    Channel widths follow ``feats``: feats[0] -> feats[1] -> feats[2].
    Spatial size is preserved (padding=1 with replicate padding).
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        layers = []
        # Build both conv stages from consecutive channel-width pairs.
        for c_in, c_out in zip(feats[:-1], feats[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1, padding_mode='replicate'),
                nn.InstanceNorm2d(c_out),
                nn.ELU(),
            ])
        self.cnn = nn.Sequential(*layers)

    def forward(self, x):
        # [B, feats[0], H, W] -> [B, feats[2], H, W]; channels-only change.
        return self.cnn(x)
|
||||
|
||||
class GCNBlock(nn.Module):
    """Two-layer GCN over the 4-connected grid graph of an image-shaped tensor.

    The edge index for a single ``h`` x ``w`` map is precomputed once; at
    forward time it is replicated with per-sample node offsets so the whole
    batch runs through the GCN layers in one call.
    """

    def __init__(self, in_ch, hidden_ch, out_ch, w, h):
        super().__init__()
        self.conv1 = GCNConv(in_ch, hidden_ch)
        self.conv2 = GCNConv(hidden_ch, out_ch)
        self.norm1 = nn.LayerNorm(hidden_ch)
        self.norm2 = nn.LayerNorm(out_ch)
        # Edge index of one h x w grid graph, shape [2, E].
        self.single_edge_index, _ = grid(h, w)

    def forward(self, x):
        # x: [B, C, H, W] -> [B, out_ch, H, W]
        batch, channels, height, width = x.shape
        nodes_per_map = height * width

        # Flatten every pixel into a graph node: [B * H * W, C].
        feats = x.permute(0, 2, 3, 1).reshape(batch * nodes_per_map, channels)

        edges = self._batch_edge_index(
            batch, self.single_edge_index.to(x.device), nodes_per_map)

        # Two GCN layers, each followed by LayerNorm + ELU.
        feats = F.elu(self.norm1(self.conv1(feats, edges)))
        feats = F.elu(self.norm2(self.conv2(feats, edges)))

        # Restore the [B, C', H, W] layout.
        return feats.view(batch, height, width, -1).permute(0, 3, 1, 2)

    def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
        # Shift the single-map edge index by each sample's node offset and
        # concatenate into one [2, B * E] index (addition allocates new
        # tensors, so the cached edge index is never mutated).
        shifted = [edge_index + i * num_nodes_per_batch for i in range(B)]
        return torch.cat(shifted, dim=1)
|
||||
@ -1,6 +1,31 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class RandomInputHead(nn.Module):
    """Project a 1-channel noise map into a 32-channel feature map.

    Three Conv2d(3x3)/InstanceNorm/ELU stages widen 1 -> 32 -> 64 -> 128
    channels, then a 1x1 conv projects down to 32. Spatial size is preserved.

    NOTE(review): ``in_size`` / ``out_size`` are accepted but never used by
    the visible code — presumably kept for interface symmetry; confirm.
    """

    def __init__(self, in_size=(32, 32), out_size=(32, 32)):
        super().__init__()
        widths = (1, 32, 64, 128)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1, padding_mode='replicate'),
                nn.InstanceNorm2d(c_out),
                nn.ELU(),
            ])
        self.conv = nn.Sequential(*stages)
        # Final 1x1 projection to the 32-channel output.
        self.out_conv = nn.Sequential(
            nn.Conv2d(128, 32, 1),
        )

    def forward(self, x):
        # [B, 1, H, W] -> [B, 32, H, W]
        return self.out_conv(self.conv(x))
|
||||
|
||||
class GinkaInput(nn.Module):
|
||||
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
|
||||
super().__init__()
|
||||
|
||||
@ -64,7 +64,7 @@ def outer_border_constraint_loss(pred: torch.Tensor, allowed_classes=[1, 11]):
|
||||
|
||||
return loss_unallowed
|
||||
|
||||
def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12]):
|
||||
def inner_constraint_loss(pred: torch.Tensor, allowed=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 13]):
|
||||
"""限定内部允许出现的图块种类
|
||||
|
||||
Args:
|
||||
@ -235,6 +235,21 @@ def adaptive_count_loss(
|
||||
|
||||
return total_loss
|
||||
|
||||
def input_head_illegal_loss(input_map, allowed_classes=(0, 1)):
    """Penalize probability mass assigned to disallowed classes.

    Args:
        input_map: [B, C, H, W] per-class probability map.
        allowed_classes: channel indices that carry no penalty.

    Returns:
        Scalar tensor: the summed probability of all non-allowed channels
        divided by the total element count of ``input_map``.
    """
    num_classes = input_map.shape[1]
    # 1 for penalized channels, 0 for allowed ones.
    penalty_mask = torch.ones(num_classes, device=input_map.device)
    penalty_mask[list(allowed_classes)] = 0
    weighted = input_map * penalty_mask.view(1, -1, 1, 1)
    return weighted.sum() / input_map.numel()
|
||||
|
||||
def input_head_wall_loss(input_map, max_wall_ratio=0.2, wall_class=1):
    """Hinge penalty on the average wall probability exceeding a budget.

    Args:
        input_map: [B, C, H, W] per-class probability map.
        max_wall_ratio: allowed mean wall probability; no penalty below it.
        wall_class: channel index holding the wall probability.

    Returns:
        Scalar tensor: max(mean_wall - max_wall_ratio, 0).
    """
    mean_wall = input_map[:, wall_class].mean()
    excess = mean_wall - max_wall_ratio
    return torch.clamp(excess, min=0.0)
|
||||
|
||||
class GinkaLoss(nn.Module):
|
||||
def __init__(self, minamo: MinamoModel, weight=[0.5, 0.2, 0.1, 0.2]):
|
||||
"""Ginka Model 损失函数部分
|
||||
@ -310,7 +325,7 @@ def js_divergence(p, q, eps=1e-6, softmax=False):
|
||||
kl_pm = F.kl_div(log_p, log_m, reduction='batchmean', log_target=True) # KL(p || m)
|
||||
kl_qm = F.kl_div(log_q, log_m, reduction='batchmean', log_target=True) # KL(q || m)
|
||||
|
||||
return 0.5 * (kl_pm + kl_qm)
|
||||
return torch.clamp(0.5 * (kl_pm + kl_qm), max=10)
|
||||
|
||||
def immutable_penalty_loss(
|
||||
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
|
||||
@ -322,7 +337,6 @@ def immutable_penalty_loss(
|
||||
input: 模型输出 [B, C, H, W],概率分布 (softmax 后)
|
||||
target: 原始输入图 [B, C, H, W],概率分布 (softmax 后)
|
||||
modifiable_classes: 允许被修改的类别列表
|
||||
penalty_weight: 对非允许修改区域的惩罚系数
|
||||
"""
|
||||
not_allowed = get_not_allowed(modifiable_classes, include_illegal=True)
|
||||
input_mask = pred[:, not_allowed, :, :]
|
||||
@ -330,14 +344,17 @@ def immutable_penalty_loss(
|
||||
target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1)
|
||||
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
||||
|
||||
# 差异区域(模型试图改变的地方)
|
||||
penalty = F.l1_loss(input_mask, target_mask)
|
||||
target_mask = torch.log(target_mask + 1e-6) # 转换为 log 概率分布
|
||||
input_mask = torch.log(input_mask + 1e-6) # 转换为 log 概率分布
|
||||
|
||||
return penalty
|
||||
# 差异区域(模型试图改变的地方)
|
||||
penalty = F.kl_div(input_mask, target_mask, reduction='batchmean', log_target=True)
|
||||
|
||||
return torch.clamp(penalty, max=1)
|
||||
|
||||
class WGANGinkaLoss:
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.4, 25, 0.2, 0.2, 0.01]):
|
||||
# weight: 判别器损失,L1 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||
def __init__(self, lambda_gp=100, weight=[1, 0.5, 10, 0.2, 0.2, 0.2]):
|
||||
# weight: 判别器损失,CE 损失,不可修改类型损失,图块类型损失,入口存在性损失,多样性损失
|
||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||
self.weight = weight
|
||||
|
||||
@ -402,18 +419,18 @@ class WGANGinkaLoss:
|
||||
|
||||
fake_scores, _, _ = critic(probs_fake, fake_graph, stage)
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
ce_loss = F.cross_entropy(fake, real)
|
||||
ce_loss = F.cross_entropy(fake, real) * (1 - mask_ratio)
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
# fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
ce_loss * self.weight[1] * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||
ce_loss * self.weight[1], # 蒙版越大,交叉熵损失权重越小
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -433,12 +450,12 @@ class WGANGinkaLoss:
|
||||
minamo_loss = -torch.mean(fake_scores)
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
# fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -457,13 +474,13 @@ class WGANGinkaLoss:
|
||||
immutable_loss = immutable_penalty_loss(probs_fake, F.softmax(input, dim=1), STAGE_ALLOWED[stage])
|
||||
constraint_loss = outer_border_constraint_loss(probs_fake) + inner_constraint_loss(probs_fake)
|
||||
|
||||
fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
# fake_a, fake_b = fake.chunk(2, dim=0)
|
||||
|
||||
losses = [
|
||||
minamo_loss * self.weight[0],
|
||||
immutable_loss * self.weight[2],
|
||||
constraint_loss * self.weight[3],
|
||||
-js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
# -js_divergence(fake_a, fake_b, softmax=True) * self.weight[5],
|
||||
]
|
||||
|
||||
if stage == 1:
|
||||
@ -472,3 +489,14 @@ class WGANGinkaLoss:
|
||||
losses.append(entrance_loss * self.weight[4])
|
||||
|
||||
return sum(losses)
|
||||
|
||||
def generator_input_head_loss(self, probs: torch.Tensor) -> torch.Tensor:
    """Total loss for the random-input head.

    Combines the illegal-class penalty, the wall-budget penalty, and a
    diversity reward (negative JS divergence between the two batch halves).
    """
    half_a, half_b = probs.chunk(2, dim=0)
    illegal = input_head_illegal_loss(probs)
    wall = input_head_wall_loss(probs)
    # Negative divergence: more different halves -> lower loss.
    diversity = -js_divergence(half_a, half_b, softmax=False) * 0.2
    return illegal + wall + diversity
|
||||
|
||||
@ -1,25 +1,23 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .common import GCNBlock, DoubleConvBlock
|
||||
|
||||
class StageHead(nn.Module):
    """Output head fusing a CNN branch with a GCN branch.

    A shared pre-processing stack feeds two parallel branches; their outputs
    are concatenated channel-wise, fused back to ``in_ch`` channels, pooled
    to ``out_size`` and projected to ``out_ch`` channels with a 1x1 conv.
    """

    def __init__(self, in_ch, out_ch, out_size=(13, 13)):
        super().__init__()
        # Shared pre-processing at constant channel width.
        self.head = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(in_ch),
            nn.ELU(),

            nn.Conv2d(in_ch, in_ch, 1),
            nn.InstanceNorm2d(in_ch),
            nn.ELU(),
        )
        self.cnn_head = DoubleConvBlock([in_ch, in_ch * 2, in_ch])
        # GCN branch is built for a 32x32 grid — assumes 32x32 inputs; TODO confirm.
        self.gcn_head = GCNBlock(in_ch, in_ch * 2, in_ch, 32, 32)
        self.fusion = DoubleConvBlock([in_ch * 2, in_ch * 4, in_ch])
        self.pool = nn.Sequential(
            nn.AdaptiveMaxPool2d(out_size),
            nn.Conv2d(in_ch, out_ch, 1),
        )

    def forward(self, x):
        shared = self.head(x)
        # Run both branches on the shared features, fuse channel-wise.
        fused = self.fusion(torch.cat([self.cnn_head(shared), self.gcn_head(shared)], dim=1))
        return self.pool(fused)
|
||||
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv
|
||||
from torch_geometric.utils import grid
|
||||
from shared.attention import ChannelAttention
|
||||
from .common import GCNBlock, DoubleConvBlock
|
||||
|
||||
class GinkaTransformerEncoder(nn.Module):
|
||||
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
|
||||
@ -35,7 +34,7 @@ class GinkaTransformerEncoder(nn.Module):
|
||||
return x
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_ch, out_ch, atte=True):
|
||||
def __init__(self, in_ch, out_ch, attn=True):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
||||
@ -44,63 +43,17 @@ class ConvBlock(nn.Module):
|
||||
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
)
|
||||
if atte:
|
||||
if attn:
|
||||
self.conv.append(ChannelAttention(out_ch))
|
||||
self.conv.append(nn.ELU())
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv(x)
|
||||
|
||||
class GCNBlock(nn.Module):
|
||||
def __init__(self, in_ch, hidden_ch, out_ch, w, h):
|
||||
super().__init__()
|
||||
self.conv1 = GCNConv(in_ch, hidden_ch)
|
||||
self.conv2 = GCNConv(hidden_ch, out_ch)
|
||||
self.norm1 = nn.LayerNorm(hidden_ch)
|
||||
self.norm2 = nn.LayerNorm(out_ch)
|
||||
self.single_edge_index, _ = grid(h, w) # [2, E] for a single map
|
||||
|
||||
def forward(self, x):
|
||||
# x: [B, C, H, W]
|
||||
B, C, H, W = x.shape
|
||||
|
||||
# Reshape to [B * H * W, C]
|
||||
x = x.permute(0, 2, 3, 1).reshape(B * H * W, C)
|
||||
|
||||
# Construct batched edge index
|
||||
device = x.device
|
||||
edge_index = self._batch_edge_index(B, self.single_edge_index.to(device), H * W)
|
||||
|
||||
# Batch vector for PyG (not strictly needed for GCNConv, but useful if you switch to GAT/Pooling)
|
||||
# batch = torch.arange(B, device=device).repeat_interleave(H * W)
|
||||
|
||||
# GCN forward
|
||||
x = self.conv1(x, edge_index)
|
||||
x = F.elu(self.norm1(x))
|
||||
x = self.conv2(x, edge_index)
|
||||
x = F.elu(self.norm2(x))
|
||||
|
||||
# Reshape back to [B, C, H, W]
|
||||
x = x.view(B, H, W, -1).permute(0, 3, 1, 2)
|
||||
return x
|
||||
|
||||
def _batch_edge_index(self, B, edge_index, num_nodes_per_batch):
|
||||
# 批次偏移 edge_index
|
||||
edge_index = edge_index.clone() # [2, E]
|
||||
batch_edge_index = []
|
||||
for i in range(B):
|
||||
offset = i * num_nodes_per_batch
|
||||
batch_edge_index.append(edge_index + offset)
|
||||
return torch.cat(batch_edge_index, dim=1)
|
||||
|
||||
class FusionModule(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, 1),
|
||||
nn.InstanceNorm2d(out_ch),
|
||||
nn.ELU()
|
||||
)
|
||||
self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
|
||||
|
||||
def forward(self, x1, x2):
|
||||
x = torch.cat([x1, x2], dim=1)
|
||||
|
||||
@ -4,17 +4,16 @@ import sys
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
import cv2
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
from .model.model import GinkaModel
|
||||
from .dataset import GinkaWGANDataset
|
||||
from .model.loss import WGANGinkaLoss
|
||||
from .model.input import RandomInputHead
|
||||
from minamo.model.model import MinamoScoreModule
|
||||
from minamo.model.similarity import MinamoSimilarityModel
|
||||
from shared.graph import batch_convert_soft_map_to_graph
|
||||
from shared.image import matrix_to_image_cv
|
||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||
|
||||
BATCH_SIZE = 16
|
||||
|
||||
@ -67,17 +66,15 @@ def train():
|
||||
|
||||
c_steps = 5
|
||||
g_steps = 1
|
||||
# 1 代表课程学习阶段,2 代表课程学习后,逐渐转为联合学习的阶段
|
||||
# 3 代表课程学习后的联合遮挡学习阶段,4 代表最后随机输入的联合学习阶段
|
||||
# 训练阶段
|
||||
train_stage = 1
|
||||
last_stage = False
|
||||
mask_ratio = 0.2 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||
random_ratio = 0
|
||||
stage3_epoch = 0 # 第三阶段 epoch 数,若干轮后进入第四阶段
|
||||
stage_epoch = 0 # 记录当前阶段的 epoch 数,用于控制训练过程
|
||||
|
||||
ginka = GinkaModel()
|
||||
minamo = MinamoScoreModule()
|
||||
ginka.to(device)
|
||||
minamo.to(device)
|
||||
ginka = GinkaModel().to(device)
|
||||
ginka_head = RandomInputHead().to(device)
|
||||
minamo = MinamoScoreModule().to(device)
|
||||
|
||||
dataset = GinkaWGANDataset(args.train, device)
|
||||
dataset_val = GinkaWGANDataset(args.validate, device)
|
||||
@ -85,6 +82,7 @@ def train():
|
||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
|
||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_head = optim.Adam(ginka_head.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
||||
|
||||
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
||||
@ -112,15 +110,15 @@ def train():
|
||||
if data_ginka.get("mask_ratio") is not None:
|
||||
mask_ratio = data_ginka["mask_ratio"]
|
||||
|
||||
if data_ginka.get("random_ratio") is not None:
|
||||
random_ratio = data_ginka["random_ratio"]
|
||||
|
||||
if data_ginka.get("stage_epoch3") is not None:
|
||||
stage3_epoch = data_ginka["stage_epoch3"]
|
||||
if data_ginka.get("stage_epoch") is not None:
|
||||
stage_epoch = data_ginka["stage_epoch"]
|
||||
|
||||
if data_ginka.get("stage") is not None:
|
||||
train_stage = data_ginka["stage"]
|
||||
|
||||
if data_ginka.get("last_stage") is not None:
|
||||
last_stage = data_ginka["last_stage"]
|
||||
|
||||
if args.load_optim:
|
||||
if data_ginka.get("optim_state") is not None:
|
||||
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
|
||||
@ -131,13 +129,11 @@ def train():
|
||||
dataset.mask_ratio1 = mask_ratio
|
||||
dataset.mask_ratio2 = mask_ratio
|
||||
dataset.mask_ratio3 = mask_ratio
|
||||
dataset.random_ratio = random_ratio
|
||||
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset_val.mask_ratio1 = mask_ratio
|
||||
dataset_val.mask_ratio2 = mask_ratio
|
||||
dataset_val.mask_ratio3 = mask_ratio
|
||||
dataset_val.random_ratio = random_ratio
|
||||
|
||||
print("Train from loaded state.")
|
||||
|
||||
@ -152,16 +148,34 @@ def train():
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
|
||||
|
||||
if train_stage == 4:
|
||||
# 最后一个阶段训练输入头
|
||||
count = 5 if stage_epoch <= 20 else 2
|
||||
for _ in range(count):
|
||||
optimizer_head.zero_grad()
|
||||
output = F.softmax(ginka_head(masked1), dim=1)
|
||||
loss_head = criterion.generator_input_head_loss(output)
|
||||
loss_head.backward()
|
||||
optimizer_head.step()
|
||||
|
||||
# ---------- 训练判别器
|
||||
for _ in range(c_steps):
|
||||
# 生成假样本
|
||||
optimizer_minamo.zero_grad()
|
||||
optimizer_ginka.zero_grad()
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||
optimizer_head.zero_grad()
|
||||
|
||||
with torch.no_grad():
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||
|
||||
elif train_stage == 3:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
|
||||
elif train_stage == 4:
|
||||
input = F.softmax(ginka_head(masked1), dim=1)
|
||||
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
|
||||
|
||||
|
||||
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
||||
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
||||
@ -183,6 +197,7 @@ def train():
|
||||
for _ in range(g_steps):
|
||||
optimizer_minamo.zero_grad()
|
||||
optimizer_ginka.zero_grad()
|
||||
optimizer_head.zero_grad()
|
||||
if train_stage == 1 or train_stage == 2:
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
|
||||
|
||||
@ -199,10 +214,12 @@ def train():
|
||||
loss_ce_total += loss_ce.detach()
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, False)
|
||||
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
|
||||
|
||||
fake1, fake2, fake3 = gen_total(ginka, input, True, False)
|
||||
|
||||
if train_stage == 3:
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, masked1)
|
||||
loss_g1 = criterion.generator_loss_total_with_input(minamo, 1, fake1, input)
|
||||
else:
|
||||
loss_g1 = criterion.generator_loss_total(minamo, 1, fake1)
|
||||
loss_g2 = criterion.generator_loss_total_with_input(minamo, 2, fake2, fake1)
|
||||
@ -221,43 +238,42 @@ def train():
|
||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||
f"Epoch: {epoch + 1} | S: {train_stage} | W: {avg_dis:.6f} | " +
|
||||
f"G: {avg_loss_ginka:.6f} | D: {avg_loss_minamo:.6f} | " +
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f} | R: {random_ratio:.1f}"
|
||||
f"CE: {avg_loss_ce:.6f} | M: {mask_ratio:.1f}"
|
||||
)
|
||||
|
||||
if avg_loss_ce < 0.5:
|
||||
if avg_loss_ce < 1.0:
|
||||
low_loss_epochs += 1
|
||||
else:
|
||||
low_loss_epochs = 0
|
||||
|
||||
if low_loss_epochs >= 3 and train_stage == 2:
|
||||
if random_ratio >= 0.5:
|
||||
train_stage = 3
|
||||
random_ratio += 0.2
|
||||
random_ratio = min(random_ratio, 0.5)
|
||||
low_loss_epochs = 0
|
||||
|
||||
if low_loss_epochs >= 3 and train_stage == 1:
|
||||
if mask_ratio >= 0.9:
|
||||
train_stage = 2
|
||||
stage_epoch = 0
|
||||
mask_ratio += 0.2
|
||||
mask_ratio = min(mask_ratio, 0.9)
|
||||
low_loss_epochs = 0
|
||||
|
||||
if train_stage == 3:
|
||||
stage3_epoch += 1
|
||||
# 十轮足够了
|
||||
if stage3_epoch >= 10:
|
||||
train_stage = 4
|
||||
stage3_epoch = 0
|
||||
if train_stage == 3 or train_stage == 2:
|
||||
if stage_epoch >= 25:
|
||||
train_stage += 1
|
||||
stage_epoch = 0
|
||||
|
||||
if train_stage >= 2:
|
||||
# 第二阶段后 L1 损失不再应该生效
|
||||
if train_stage >= 3:
|
||||
# 第三阶段后交叉熵损失不再应该生效
|
||||
mask_ratio = 1.0
|
||||
|
||||
if last_stage:
|
||||
if train_stage == 2 and stage_epoch % 5 == 0:
|
||||
train_stage = 4
|
||||
|
||||
if train_stage == 4 and stage_epoch % 5 == 1:
|
||||
train_stage = 2
|
||||
|
||||
stage_epoch += 1
|
||||
|
||||
dataset.train_stage = train_stage
|
||||
dataset_val.train_stage = train_stage
|
||||
dataset.random_ratio = random_ratio
|
||||
dataset_val.random_ratio = random_ratio
|
||||
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
||||
|
||||
@ -284,8 +300,8 @@ def train():
|
||||
"g_steps": g_steps,
|
||||
"stage": train_stage,
|
||||
"mask_ratio": mask_ratio,
|
||||
"random_ratio": random_ratio,
|
||||
"stage3_epoch": stage3_epoch,
|
||||
"stage_epoch": stage_epoch,
|
||||
"last_stage": last_stage
|
||||
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict(),
|
||||
@ -300,7 +316,8 @@ def train():
|
||||
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||
|
||||
elif train_stage == 3 or train_stage == 4:
|
||||
fake1, fake2, fake3 = gen_total(ginka, masked1, True, True)
|
||||
input = masked1 if train_stage == 3 else F.softmax(ginka_head(masked1), dim=1)
|
||||
fake1, fake2, fake3 = gen_total(ginka, input, True, True)
|
||||
|
||||
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
||||
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
||||
|
||||
BIN
tiles/13.png
Normal file
BIN
tiles/13.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 419 B |
2
train.sh
2
train.sh
@ -1,4 +1,4 @@
|
||||
# 从头训练
|
||||
python3 -u -m ginka.train_wgan >> output.log
|
||||
# 接续训练
|
||||
python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
||||
python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log
|
||||
|
||||
Loading…
Reference in New Issue
Block a user