refactor: 课程训练

This commit is contained in:
unanmed 2025-04-13 21:06:07 +08:00
parent 99f46150be
commit f6b1ad6ebd
11 changed files with 363 additions and 843 deletions

View File

@ -3,10 +3,18 @@ import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from minamo.model.model import MinamoModel
from shared.graph import differentiable_convert_to_data
import torch
import torch.nn.functional as F
from typing import List
from shared.utils import random_smooth_onehot
STAGE1_MASK = [0, 1, 10, 11]
STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12]
STAGE2_MASK = [6, 7, 8, 9]
STAGE2_REMOVE = [2, 3, 4, 5, 12]
STAGE3_MASK = [2, 3, 4, 5, 12]
STAGE3_REMOVE = []
def load_data(path: str):
with open(path, 'r', encoding="utf-8") as f:
data = json.load(f)
@ -23,38 +31,45 @@ def load_minamo_gan_data(data: list):
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
return res
class GinkaDataset(Dataset):
def __init__(self, data_path: str, device, minamo: MinamoModel):
self.data = load_data(data_path) # 自定义数据加载函数
self.max_size = 32
self.minamo = minamo
self.device = device
def apply_curriculum_mask(
maps: torch.Tensor, # [B, C, H, W]
mask_classes: List[int], # 要遮挡的类别索引
remove_classes: List[int], # 要移除的类别索引
mask_ratio: float # 遮挡比例 0~1
) -> torch.Tensor:
C, H, W = maps.shape
device = maps.device
masked_maps = maps.clone()
def __len__(self):
return len(self.data)
# Step 1: 移除不需要的类别(全设为 0 类)
if remove_classes:
remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
masked_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地”
removed_maps = masked_maps.clone()
def __getitem__(self, idx):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
min_main = random.uniform(0.75, 0.9)
max_main = random.uniform(0.9, 1)
epsilon = random.uniform(0, 0.25)
target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon)
graph = differentiable_convert_to_data(target_smooth).to(self.device)
target = target.to(self.device)
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
return {
"target_vision_feat": vision_feat,
"target_topo_feat": topo_feat,
"target": target,
}
# Step 2: 对指定类别随机遮挡
for cls in mask_classes:
cls_mask = masked_maps[:, cls] > 0 # 目标类别的像素布尔掩码 [H, W]
indices = cls_mask.nonzero(as_tuple=False) # 所有该类像素坐标
num_mask = int(len(indices) * mask_ratio)
if num_mask > 0:
selected = indices[torch.randperm(len(indices))[:num_mask]]
masked_maps[cls, selected[:, 0], selected[:, 1]] = 0
masked_maps[0, selected[:, 0], selected[:, 1]] = 1 # 置为“空地”
return removed_maps, masked_maps
class GinkaWGANDataset(Dataset):
def __init__(self, data_path: str, device):
self.data = load_data(data_path) # 自定义数据加载函数
self.device = device
self.train_stage = 1
self.mask_ratio1 = 0.1
self.mask_ratio2 = 0.1
self.mask_ratio3 = 0.1
self.random_ratio = 0.0
def __len__(self):
return len(self.data)
@ -63,56 +78,20 @@ class GinkaWGANDataset(Dataset):
item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
# min_main = random.uniform(0.8, 0.9)
# max_main = random.uniform(0.9, 1)
# epsilon = random.uniform(0, 0.2)
# target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
if self.train_stage == 1:
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
elif self.train_stage == 2:
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9))
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9))
return target
class MinamoGANDataset(Dataset):
    """Dataset feeding the Minamo critic with (generated, reference) map pairs.

    Holds a pool of reference pairs loaded from disk plus the most recent
    generator outputs injected via ``set_data``. Each item yields both maps,
    their vision/topology similarity targets, and the graph representations.
    """

    def __init__(self, refer_data_path):
        # Reference pool loaded once from the JSON dataset on disk.
        self.refer = load_minamo_gan_data(load_data(refer_data_path))
        self.data = list()
        # Seed the working set with 1000 random reference samples so the
        # dataset is non-empty before the first set_data call.
        self.data.extend(random.sample(self.refer, 1000))

    def set_data(self, data: list):
        """Replace the working set with fresh generator data plus a
        reference top-up of about a quarter of its size."""
        self.data.clear()
        self.data.extend(data)
        k = min(len(data) / 4, len(self.refer))
        self.data.extend(random.sample(self.refer, int(k)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # map2 is assumed to be the reference map.
        item = self.data[idx]
        map1, map2, vis_sim, topo_sim, review = item
        # Check the review flag: when it is falsy, map1 is already a
        # probability distribution and needs no one-hot conversion.
        if review:
            map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        else:
            map1 = torch.FloatTensor(map1)
        map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        # Random label-smoothing parameters shared by both maps.
        min_main = random.uniform(0.75, 0.9)
        max_main = random.uniform(0.9, 1)
        epsilon = random.uniform(0, 0.25)
        if review:
            map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
        map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)
        graph1 = differentiable_convert_to_data(map1)
        graph2 = differentiable_convert_to_data(map2)
        return (
            map1,
            map2,
            torch.FloatTensor([vis_sim]),
            torch.FloatTensor([topo_sim]),
            graph1,
            graph2
        )
if self.random_ratio > 0:
removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
return removed1, masked1, removed2, masked2, removed3, masked3

View File

@ -2,43 +2,24 @@ import torch
import torch.nn as nn
class GinkaInput(nn.Module):
def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)):
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
super().__init__()
self.out_size = out_size
self.fc = nn.Sequential(
nn.Linear(feat_dim, size[0] * size[1] * out_ch),
nn.Unflatten(1, (out_ch, *size))
nn.Linear(in_size[0] * in_size[1], out_size[0] * out_size[1]),
nn.LayerNorm(out_size[0] * out_size[1]),
nn.ELU()
)
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.ELU()
)
def forward(self, x):
B, C, H, W = x.shape
x = x.view(B, C, H * W)
x = self.fc(x)
x = x.view(B, C, self.out_size[0], self.out_size[1])
x = self.conv(x)
return x
class FeatureEncoder(nn.Module):
def __init__(self, feat_dim, size, mid_ch, out_ch):
super().__init__()
self.encode = nn.Sequential(
nn.Linear(feat_dim, mid_ch * size * size),
nn.Unflatten(1, (mid_ch, size, size)),
nn.Conv2d(mid_ch, out_ch, 1)
)
def forward(self, x):
x = self.encode(x)
return x
class GinkaFeatureInput(nn.Module):
def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64):
super().__init__()
self.encode1 = FeatureEncoder(feat_dim, 32, mid_ch, out_ch)
self.encode2 = FeatureEncoder(feat_dim, 16, mid_ch * 2, out_ch * 2)
self.encode3 = FeatureEncoder(feat_dim, 8, mid_ch * 4, out_ch * 4)
self.encode4 = FeatureEncoder(feat_dim, 4, mid_ch * 8, out_ch * 8)
self.encode5 = FeatureEncoder(feat_dim, 2, mid_ch * 16, out_ch * 16)
def forward(self, x):
x1 = self.encode1(x)
x2 = self.encode2(x)
x3 = self.encode3(x)
x4 = self.encode4(x)
x5 = self.encode5(x)
return x1, x2, x3, x4, x5

View File

@ -13,6 +13,13 @@ from shared.similarity.vision import calculate_visual_similarity
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 12
STAGE_ALLOWED = [
[],
[0, 1, 10, 11],
[6, 7, 8, 9,],
[2, 3, 4, 5, 12]
]
def get_not_allowed(classes: list[int], include_illegal=False):
res = list()
for num in range(0, CLASS_NUM):
@ -301,24 +308,47 @@ def js_divergence(p, q, eps=1e-8):
return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
def immutable_penalty_loss(
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
"""
惩罚模型修改不可更改区域的损失
Args:
input: 模型输出 [B, C, H, W]概率分布 (softmax )
target: 原始输入图 [B, C, H, W]概率分布 (softmax )
modifiable_classes: 允许被修改的类别列表
penalty_weight: 对非允许修改区域的惩罚系数
"""
not_allowed = get_not_allowed(modifiable_classes, include_illegal=True)
input_mask = pred[:, not_allowed, :, :]
with torch.no_grad():
target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1)
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
# 差异区域(模型试图改变的地方)
penalty = F.cross_entropy(input_mask, target_mask)
return penalty
class WGANGinkaLoss:
def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4):
def __init__(self, lambda_gp=100, weight=[1, 0.4, 10, 0.2, 0.2]):
# weight: 判别器损失L1 损失,不可修改类型损失
self.lambda_gp = lambda_gp # 梯度惩罚系数
self.weight = weight
self.diversity_lamda = diversity_lamda
def compute_gradient_penalty(self, critic, real_data, fake_data):
def compute_gradient_penalty(self, critic, stage, real_data, fake_data):
# 进行插值
batch_size = real_data.size(0)
epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
interp_data = interpolate_data(real_data, fake_data, epsilon_data)
interp_graph = batch_convert_soft_map_to_graph(interp_data)
interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)
interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device)
# 对图像进行反向传播并计算梯度
interp_data.requires_grad_()
interp_graph.x.requires_grad_()
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph)
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage)
# 计算梯度
grad_vis = torch.autograd.grad(
@ -344,21 +374,21 @@ class WGANGinkaLoss:
return gp_loss
def discriminator_loss(
self, critic, real_data: torch.Tensor,
real_graph: torch.Tensor, fake_data: torch.Tensor
):
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
""" 判别器损失函数 """
real_graph = batch_convert_soft_map_to_graph(real_data)
fake_graph = batch_convert_soft_map_to_graph(fake_data)
real_scores, _, _ = critic(real_data, real_graph)
fake_scores, _, _ = critic(fake_data, fake_graph)
# print("Critic 输出范围", fake_scores.min().item(), fake_scores.max().item(), real_scores.min().item(), real_scores.max().item())
real_scores, _, _ = critic(real_data, real_graph, stage)
fake_scores, _, _ = critic(fake_data, fake_graph, stage)
# Wasserstein 距离
d_loss = fake_scores.mean() - real_scores.mean()
grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data)
grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data)
return d_loss, d_loss + self.lambda_gp * grad_loss
total_loss = d_loss + self.lambda_gp * grad_loss
return total_loss, d_loss
def calculate_similarity_one(self, map1, map2):
topo1 = build_topological_graph(map1)
@ -368,73 +398,29 @@ class WGANGinkaLoss:
topo_sim = overall_similarity(topo1, topo2)
return vis_sim, topo_sim
def discriminator_loss_assist(self, critic, fake_data1, fake_data2):
graph1 = batch_convert_soft_map_to_graph(fake_data1)
graph2 = batch_convert_soft_map_to_graph(fake_data2)
vis_feat_1, topo_feat_1 = critic(fake_data1, graph1)
vis_feat_2, topo_feat_2 = critic(fake_data2, graph2)
def generator_loss(self, critic, stage, mask_ratio, real, fake, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
""" 生成器损失函数 """
fake_graph = batch_convert_soft_map_to_graph(fake)
batch1 = torch.argmax(fake_data1, dim=1).cpu().tolist()
batch2 = torch.argmax(fake_data2, dim=1).cpu().tolist()
vis_sim_real = []
topo_sim_real = []
for i in range(len(batch1)):
vis_sim, topo_sim = self.calculate_similarity_one(batch1[i], batch2[i])
vis_sim_real.append(vis_sim)
topo_sim_real.append(topo_sim)
vis_sim_real = torch.Tensor(vis_sim_real)
topo_sim_real = torch.Tensor(topo_sim_real)
pred_vis_sim = F.cosine_similarity(vis_feat_1, vis_feat_2).cpu()
pred_topo_sim = F.cosine_similarity(topo_feat_1, topo_feat_2).cpu()
loss1 = F.l1_loss(pred_vis_sim, vis_sim_real) * VISION_WEIGHT + F.l1_loss(pred_topo_sim, topo_sim_real) * TOPO_WEIGHT
return loss1
def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2):
loss1 = self.discriminator_loss_assist(critic, real_data, fake_data1)
loss2 = self.discriminator_loss_assist(critic, real_data, fake_data2)
loss3 = self.discriminator_loss_assist(critic, fake_data1, fake_data2)
return loss1 / 3.0 + loss2 / 3.0 + loss3 / 3.0
def generator_loss_one(self, critic, fake, fake_graph):
fake_scores, _, _ = critic(fake, fake_graph)
fake_scores, _, _ = critic(fake, fake_graph, stage)
minamo_loss = -torch.mean(fake_scores)
class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
entrance_loss = entrance_constraint_loss(fake)
ce_loss = F.cross_entropy(fake, real)
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
losses = [
minamo_loss * self.weight[0],
class_loss * self.weight[1],
entrance_loss * self.weight[2]
ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
immutable_loss * self.weight[2],
constraint_loss * self.weight[3]
]
return sum(losses)
def generator_loss(self, critic, critic_assist, fake1, fake2):
""" 生成器损失函数 """
fake_graph1 = batch_convert_soft_map_to_graph(fake1)
fake_graph2 = batch_convert_soft_map_to_graph(fake2)
if stage == 1:
# 第一个阶段检查入口存在性
entrance_loss = entrance_constraint_loss(fake)
losses.append(entrance_loss * self.weight[4])
loss1 = self.generator_loss_one(critic, fake1, fake_graph1)
loss2 = self.generator_loss_one(critic, fake2, fake_graph2)
# print(losses[2].item())
# vis_feat1, topo_feat1 = critic_assist(fake1, fake_graph1)
# vis_feat2, topo_feat2 = critic_assist(fake2, fake_graph2)
# vis_sim = F.cosine_similarity(vis_feat1, vis_feat2)
# topo_sim = F.cosine_similarity(topo_feat1, topo_feat2)
# similarity = vis_sim * VISION_WEIGHT + topo_sim * TOPO_WEIGHT
# print(similarity.mean().item())
# div_loss = F.l1_loss(fake1[:, :, 1:-1, 1:-1], fake2[:, :, 1:-1, 1:-1])
return loss1 * 0.5 + loss2 * 0.5\
# + self.diversity_lamda * F.relu(0.7 - div_loss).mean()
# + self.diversity_lamda * F.relu(similarity - 0.4).mean()
return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss

View File

@ -9,27 +9,29 @@ def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
class GinkaModel(nn.Module):
def __init__(self, feat_dim=1024, base_ch=64, out_ch=32):
def __init__(self, base_ch=64, out_ch=32):
"""Ginka Model 模型定义部分
"""
super().__init__()
self.unet = GinkaUNet(base_ch, base_ch, feat_dim)
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
self.unet = GinkaUNet(32, base_ch, base_ch)
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x):
def forward(self, x, stage):
"""
Args:
x: 参考地图的特征向量
Returns:
logits: 输出logits [BS, num_classes, H, W]
"""
x = self.input(x)
x = self.unet(x)
x = self.output(x)
x = self.output(x, stage)
return F.softmax(x, dim=1)
# 检查显存占用
if __name__ == "__main__":
feat = torch.randn((1, 1024)).cuda()
input = torch.randn((1, 32, 13, 13)).cuda()
# 初始化模型
model = GinkaModel().cuda()
@ -37,14 +39,13 @@ if __name__ == "__main__":
print_memory("初始化后")
# 前向传播
output = model(feat)
output = model(input, 1)
print_memory("前向传播后")
print(f"输入形状: feat={feat.shape}")
print(f"输入形状: feat={input.shape}")
print(f"输出形状: output={output.shape}")
# print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
# print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,13 +1,42 @@
import torch
import torch.nn as nn
class GinkaOutput(nn.Module):
def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
class StageHead(nn.Module):
def __init__(self, in_ch, out_ch, out_size=(13, 13)):
super().__init__()
self.conv_down = nn.Sequential(
self.head = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(in_ch),
nn.ELU(),
nn.Conv2d(in_ch, in_ch, 1),
nn.InstanceNorm2d(in_ch),
nn.ELU(),
)
self.pool = nn.Sequential(
nn.AdaptiveMaxPool2d(out_size),
nn.Conv2d(in_ch, out_ch, 1)
)
def forward(self, x):
return self.conv_down(x)
x = self.head(x)
x = self.pool(x)
return x
class GinkaOutput(nn.Module):
    """Output module with one independent prediction head per training stage."""

    def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
        super().__init__()
        # One StageHead per curriculum stage. Attribute names head1/head2/head3
        # are kept stable so existing checkpoints remain loadable.
        self.head1 = StageHead(in_ch, out_ch, out_size)
        self.head2 = StageHead(in_ch, out_ch, out_size)
        self.head3 = StageHead(in_ch, out_ch, out_size)

    def forward(self, x, stage):
        """Route the feature map through the head matching `stage` (1, 2 or 3)."""
        dispatch = {1: self.head1, 2: self.head2, 3: self.head3}
        head = dispatch.get(stage)
        if head is None:
            raise RuntimeError("Unknown generate stage.")
        return head(x)

View File

@ -198,15 +198,15 @@ class GinkaBottleneck(nn.Module):
return x
class GinkaUNet(nn.Module):
def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
def __init__(self, in_ch=32, base_ch=64, out_ch=32):
"""Ginka Model UNet 部分
"""
super().__init__()
self.input = GinkaTransformerEncoder(
in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
token_size=4, ff_dim=feat_dim*2, num_layers=4
)
self.down1 = ConvBlock(2, base_ch)
# self.input = GinkaTransformerEncoder(
# in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
# token_size=4, ff_dim=feat_dim*2, num_layers=4
# )
self.down1 = ConvBlock(in_ch, base_ch)
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
@ -223,10 +223,6 @@ class GinkaUNet(nn.Module):
)
def forward(self, x):
B, D = x.shape # [B, 1024]
x = x.view(B, 4, D // 4) # [B, 4, 256]
x = self.input(x) # [B, 4, 512]
x = x.view(B, 2, 32, 32) # [B, 2, 32, 32]
x1 = self.down1(x) # [B, 64, 32, 32]
x2 = self.down2(x1) # [B, 128, 16, 16]
x3 = self.down3(x2) # [B, 256, 8, 8]
@ -237,5 +233,6 @@ class GinkaUNet(nn.Module):
x = self.up1(x4, x3) # [B, 256, 8, 8]
x = self.up2(x, x2) # [B, 128, 16, 16]
x = self.up3(x, x1) # [B, 64, 32, 32]
x = self.final(x) # [B, 32, 32, 32]
return self.final(x) # [B, 32, 32, 32]
return x

View File

@ -1,134 +0,0 @@
import os
from datetime import datetime
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
from shared.args import parse_arguments
BATCH_SIZE = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
def train():
    """Supervised pre-training loop for GinkaModel.

    Uses a frozen MinamoModel as a feature extractor: its vision and
    topology feature vectors are concatenated and fed to GinkaModel, which
    is optimized against GinkaLoss. Validates and checkpoints every 5
    epochs; final weights are written to result/ginka.pth.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
    args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
    model = GinkaModel()
    model.to(device)
    # Minamo only supplies features in this phase, so it stays in eval mode.
    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)
    minamo.eval()
    # Prepare the training / validation datasets.
    dataset = GinkaDataset(args.train, device, minamo)
    dataset_val = GinkaDataset(args.validate, device, minamo)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    dataloader_val = DataLoader(
        dataset_val,
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    # Optimizer and cosine-restart learning-rate scheduler.
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
    criterion = GinkaLoss(minamo)
    if args.resume:
        data = torch.load(args.from_state, map_location=device)
        model.load_state_dict(data["model_state"], strict=False)
        if args.load_optim:
            optimizer.load_state_dict(data["optimizer_state"])
        print("Train from loaded state.")
    else:
        # When training from scratch, start with the minamo loss weight at 0.
        criterion.weight[0] = 0.0
    # Main training loop.
    for epoch in tqdm(range(args.epochs)):
        model.train()
        total_loss = 0
        # When training from scratch, restore the minamo loss weight at epoch 10.
        if not args.resume and epoch == 10:
            criterion.weight[0] = 0.5
        for batch in dataloader:
            # Move batch data to the device.
            target = batch["target"].to(device)
            target_vision_feat = batch["target_vision_feat"].to(device)
            target_topo_feat = batch["target_topo_feat"].to(device)
            feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
            # Forward pass.
            optimizer.zero_grad()
            _, output_softmax = model(feat_vec)
            # Compute the loss.
            losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
            # Backward pass.
            losses.backward()
            optimizer.step()
            total_loss += losses.item()
        avg_loss = total_loss / len(dataloader)
        tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
        # Advance the learning-rate schedule.
        scheduler.step()
        if (epoch + 1) % 5 == 0:
            # Periodic validation pass and checkpoint.
            loss_val = 0
            model.eval()
            with torch.no_grad():
                for batch in dataloader_val:
                    # Move batch data to the device.
                    target = batch["target"].to(device)
                    target_vision_feat = batch["target_vision_feat"].to(device)
                    target_topo_feat = batch["target_topo_feat"].to(device)
                    feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
                    # Forward pass.
                    output, output_softmax = model(feat_vec)
                    print(torch.argmax(output, dim=1)[0])
                    # Compute the loss.
                    losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
                    loss_val += losses.item()
            avg_val_loss = loss_val / len(dataloader_val)
            tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
            torch.save({
                "model_state": model.state_dict(),
                # "optimizer_state": optimizer.state_dict(),
            }, f"result/ginka_checkpoint/{epoch + 1}.pth")
    print("Train ended.")
    torch.save({
        "model_state": model.state_dict(),
        # "optimizer_state": optimizer.state_dict(),
    }, f"result/ginka.pth")


if __name__ == "__main__":
    torch.set_num_threads(4)
    train()

View File

@ -1,410 +0,0 @@
import argparse
import socket
import struct
import os
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import cv2
import numpy as np
from .model.model import GinkaModel
from .model.loss import GinkaLoss, WGANGinkaLoss
from .dataset import GinkaDataset, MinamoGANDataset
from minamo.model.model import MinamoModel
from minamo.model.loss import MinamoLoss
from shared.image import matrix_to_image_cv
BATCH_SIZE = 32
EPOCHS_GINKA = 5
EPOCHS_MINAMO = 2
SOCKET_PATH = "./tmp/ginka_uds"
LOSS_PATH = "result/gan/a-loss.txt"
REPLAY_PATH = "datasets/replay.bin"
VISION_ALPHA = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
os.makedirs("result/gan", exist_ok=True)
os.makedirs("tmp", exist_ok=True)
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")
if not os.path.exists(REPLAY_PATH):
with open(REPLAY_PATH, 'wb') as f:
f.write(b'\x00\x00\x00\x00')
def parse_arguments(argv=None):
    """Parse command-line options for the GAN training loop.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with resume, from_state, train, validate,
        from_cycle and to_cycle fields.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # BUG FIX: the original used type=bool, which makes ANY non-empty string
    # (including "False") parse as True. A store_true flag keeps the default
    # of False and gives the intended on/off semantics.
    parser.add_argument("--resume", action="store_true", default=False)
    parser.add_argument("--from_state", type=str, default="result/ginka.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default='ginka-eval.json')
    parser.add_argument("--from_cycle", type=int, default=0)
    parser.add_argument("--to_cycle", type=int, default=100)
    args = parser.parse_args(argv)
    return args
def parse_ginka_batch(batch):
    """Move one Ginka batch onto the global device.

    Returns (target, vision_feat, topo_feat, concatenated feature vector).
    """
    target = batch["target"].to(device)
    vision_feat = batch["target_vision_feat"].to(device).squeeze(1)
    topo_feat = batch["target_topo_feat"].to(device).squeeze(1)
    # Concatenate vision + topology features along the channel dimension.
    feat_vec = torch.cat([vision_feat, topo_feat], dim=1).to(device)
    return target, vision_feat, topo_feat, feat_vec
def parse_minamo_batch(batch):
    """Unpack a Minamo GAN batch and transfer every element to the device.

    Returns (map1, map2, vision_sim, topo_sim, graph1, graph2), all moved
    to the module-level ``device``.
    """
    # The batch is ordered: maps, similarity targets, graphs.
    moved = tuple(part.to(device) for part in batch)
    map_a, map_b, vision_sim, topo_sim, graph_a, graph_b = moved
    return map_a, map_b, vision_sim, topo_sim, graph_a, graph_b
def send_all(sock, data):
    """Send the entire buffer, looping over partial sends.

    Raises:
        RuntimeError: if the socket reports a zero-byte send (peer gone).
    """
    offset = 0
    total = len(data)
    while offset < total:
        pushed = sock.send(data[offset:])
        if not pushed:
            raise RuntimeError("Socket connection broken")
        offset += pushed
def recv_all(sock: socket.socket, length: int):
    """Receive exactly `length` bytes, looping over partial reads.

    Raises:
        ConnectionError: if the peer closes the connection early.
    """
    data = bytes()
    while len(data) < length:
        packet = sock.recv(length - len(data))  # request only the remainder
        if not packet:
            raise ConnectionError("连接中断")
        data += packet
    return data


def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
    """Read one round of Minamo comparison data from the node process.

    Wire protocol from the node side (big-endian, sizes in bytes):
      2 - tensor count (tc); 2 - review count (rc), review data follows train data;
      1*tc - number of comparison maps for each delivered map tensor;
      2*4*(N+rc) - float pairs laid out as vis, topo, vis, topo, ...;
      N*1*H*W - comparison maps (int8); rc*2*H*W - review map tensors.

    Args:
        sock: Connected socket to the node client.
        maps: Generated probability maps, shape [tc, C, H, W].

    Returns:
        List of tuples (map1, map2, vision_sim, topo_sim, is_review) for the
        current generator round (review payload is consumed but not returned).
    """
    _, _, H, W = maps.shape
    # ROBUSTNESS: the original read tc/rc with bare sock.recv(2), which may
    # short-read; recv_all guarantees the full 2 bytes.
    tc = struct.unpack('>h', recv_all(sock, 2))[0]
    rc = struct.unpack('>h', recv_all(sock, 2))[0]
    count: list = struct.unpack(f">{tc}b", recv_all(sock, 1 * tc))
    N = sum(count)
    sim_buf = recv_all(sock, 2 * 4 * (N + rc))
    com_buf = recv_all(sock, N * 1 * H * W)
    review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes()
    sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)
    com = struct.unpack(f">{N * 1 * H * W}b", com_buf)
    # BUG FIX: the original format string f">{rc * 2 * H * W}" was missing a
    # type character and raised struct.error whenever rc > 0; 'b' matches the
    # int8 encoding used for the other map payloads.
    review = struct.unpack(f">{rc * 2 * H * W}b", review_buf) if rc > 0 else list()

    res = list()
    flatten_idx = 0
    # Collect the comparison results produced for this generator round.
    for idx in range(tc):
        com_count = count[idx]
        for _ in range(com_count):
            com_start = flatten_idx * H * W
            com_end = (flatten_idx + 1) * H * W
            vis_sim = sim[flatten_idx * 2]
            topo_sim = sim[flatten_idx * 2 + 1]
            com_data = com[com_start:com_end]
            flatten_idx += 1
            com_map = np.array(com_data, dtype=np.int8).reshape(H, W)
            # (map1, map2, vision_similarity, topo_similarity, is_review)
            res.append((maps[idx], com_map, vis_sim, topo_sim, False))
    return res
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments()
ginka = GinkaModel()
ginka.to(device)
minamo = MinamoModel(32)
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
minamo.to(device)
minamo.eval()
# 准备数据集
ginka_dataset = GinkaDataset(args.train, device, minamo)
ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json")
minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE // 2, shuffle=True)
minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True)
# 设定优化器与调度器
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
criterion_ginka = GinkaLoss(minamo)
optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=EPOCHS_MINAMO, T_mult=2, eta_min=1e-6)
criterion_minamo = MinamoLoss()
criterion = WGANGinkaLoss()
# 用于生成图片
tile_dict = dict()
for file in os.listdir('tiles'):
name = os.path.splitext(file)[0]
tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
# 与 node 端通讯
if os.path.exists(SOCKET_PATH):
os.remove(SOCKET_PATH)
server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
server.bind(SOCKET_PATH)
server.listen(1)
if args.resume:
data = torch.load(args.from_state, map_location=device)
ginka.load_state_dict(data["model_state"], strict=False)
print("Train from loaded state.")
print("Waiting for client connection...")
conn, _ = server.accept()
print("Client connected.")
for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
# -------------------- 训练生成器
for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model", leave=False):
ginka.train()
minamo.eval()
total_loss = 0
for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
# 数据迁移到设备
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
# 前向传播
optimizer_ginka.zero_grad()
_, output_softmax = ginka(feat_vec)
# 计算损失
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
# 反向传播
losses.backward()
optimizer_ginka.step()
total_loss += losses.item()
avg_loss = total_loss / len(ginka_dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")
# 学习率调整
scheduler_ginka.step(epoch + 1)
if (epoch + 1) % 5 == 0:
loss_val = 0
ginka.eval()
idx = 0
with torch.no_grad():
for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"):
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
output, output_softmax = ginka(feat_vec)
losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
loss_val += losses.item()
if epoch + 1 == EPOCHS_GINKA:
# 最后一次验证的时候顺带生成图片
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
for matrix in map_matrix:
image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
idx += 1
avg_val_loss = loss_val / len(ginka_dataloader_val)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": ginka.state_dict()
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
# 使用训练集生成 minamo 训练数据,更准确
gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
with torch.no_grad():
for batch in ginka_dataloader:
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
output, output_softmax = ginka(feat_vec)
prob = output_softmax.cpu().numpy()
prob_list = np.concatenate((prob_list, prob), axis=0)
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
tqdm.write(f"Cycle {cycle} Ginka train ended.")
torch.save({
"model_state": ginka.state_dict()
}, f"result/gan/ginka-{cycle}.pth")
torch.save({
"model_state": ginka.state_dict()
}, f"result/ginka.pth")
# -------------------- 生成 Minamo 的训练数据
# 数据通讯 python 输出协议,单位字节:
# 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
N, H, W = gen_list.shape
gen_bytes = gen_list.astype(np.int8).tobytes()
buf = bytearray()
buf.extend(struct.pack('>h', N)) # Tensor count
buf.extend(struct.pack('>b', H)) # Map height
buf.extend(struct.pack('>b', W)) # Map width
buf.extend(gen_bytes) # Map tensor
conn.sendall(buf)
data = parse_minamo_data(conn, prob_list)
vis_sim = 0
topo_sim = 0
for _, _, vis, topo, _ in data:
vis_sim += vis
topo_sim += topo
vis_sim /= len(data)
topo_sim /= len(data)
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}')
# 经验回放部分
with open(REPLAY_PATH, 'r+b') as f:
# 读取文件开头获取总长度
f.seek(0)
count = struct.unpack('>i', f.read(4))[0] # 取出整数
if count > 0:
replay = np.random.choice(count, size=min(count, len(data) // 4), replace=False)
replay_data = np.empty((len(replay), 32, 13, 13))
for i, n in enumerate(replay):
f.seek(n * 32 * 13 * 13 + 4)
arr = np.frombuffer(f.read(32 * 13 * 13 * 4), dtype=np.float32).reshape(32, 13, 13)
replay_data[i] = arr
map_data: np.ndarray = replay_data.argmax(axis=1)
buf = bytearray()
buf.extend(struct.pack('>h', len(replay))) # Tensor count
buf.extend(struct.pack('>b', H)) # Map height
buf.extend(struct.pack('>b', W)) # Map width
buf.extend(map_data.astype(np.int8).tobytes()) # Map tensor
conn.sendall(buf)
data.extend(parse_minamo_data(conn, replay_data))
# 把新的内容写入文件末尾
to_write = np.random.choice(N, size=min(N, 100), replace=False)
write_data = bytearray()
for n in to_write:
write_data.extend(prob_list[n].tobytes())
f.seek(0, 2) # 定位到文件末尾
f.write(write_data)
f.seek(0) # 定位到文件开头
f.write(struct.pack('>i', count + len(to_write)))
f.flush() # 确保数据被刷新到磁盘
minamo_dataset.set_data(data)
# -------------------- 训练判别器
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
ginka.eval()
minamo.train()
total_loss = 0
for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
map1, map2, vis_sim, topo_sim, graph1, graph2 = parse_minamo_batch(batch)
batch_size = map1.shape[0]
if batch_size == 1:
continue
# 前向传播
optimizer_minamo.zero_grad()
vis_feat_real, topo_feat_real = minamo(map1, graph1)
vis_feat_ref, topo_feat_ref = minamo(map2, graph2)
# 生成假数据
with torch.no_grad():
fake_feat = torch.randn((batch_size, 1024), device=device)
fake_data = ginka(fake_feat)
# 创建插值样本
alpha = torch.rand((batch_size, 1, 1, 1), device=device)
interpolates = (alpha * map2 + (1 - alpha) * fake_data).requires_grad_(True)
vis_feat_fake, topo_feat_fake = minamo(fake_data)
vis_feat_interp, topo_feat_interp = minamo(interpolates)
vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)
vis_pred_fake = F.cosine_similarity(vis_feat_fake, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_fake = F.cosine_similarity(topo_feat_fake, topo_feat_ref, dim=1).unsqueeze(-1)
vis_pred_interp = F.cosine_similarity(vis_feat_interp, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_interp = F.cosine_similarity(topo_feat_interp, topo_feat_ref, dim=1).unsqueeze(-1)
# 计算相似度
score_real = F.l1_loss(vis_pred_real, vis_sim) * VISION_ALPHA + F.l1_loss(topo_pred_real, topo_sim) * (1 - VISION_ALPHA)
score_fake = vis_pred_fake * VISION_ALPHA + topo_pred_fake * (1 - VISION_ALPHA)
score_interp = vis_pred_interp * VISION_ALPHA + topo_pred_interp * (1 - VISION_ALPHA)
# 计算损失
loss = criterion.discriminator_loss(score_real, score_fake, score_interp)
# 反向传播
loss.backward()
optimizer_minamo.step()
total_loss += loss.item()
ave_loss = total_loss / len(minamo_dataloader)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}")
scheduler_minamo.step(epoch + 1)
# 每十轮推理一次验证集
if epoch + 1 == EPOCHS_MINAMO:
minamo.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
vis_feat_real, topo_feat_real = minamo(map1_val, graph1)
vis_feat_ref, topo_feat_ref = minamo(map2_val, graph2)
vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)
# 计算损失
loss_val = criterion_minamo(vis_pred_real, topo_pred_real, vision_simi_val, topo_simi_val)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(minamo_dataloader_val)
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
torch.save({
"model_state": minamo.state_dict()
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
tqdm.write(f"Cycle {cycle} Minamo train ended.")
torch.save({
"model_state": minamo.state_dict()
}, f"result/gan/minamo-{cycle}.pth")
torch.save({
"model_state": minamo.state_dict()
}, f"result/minamo.pth")
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f' | Minamo: {avg_val_loss:.12f}\n')
print("Train ended.")
if __name__ == "__main__":
    # Cap intra-op CPU threads so data loading / CPU ops don't oversubscribe
    # cores while the GPU training loop runs.
    torch.set_num_threads(4)
    train()

View File

@ -16,7 +16,7 @@ from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
BATCH_SIZE = 32
BATCH_SIZE = 16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@ -30,39 +30,56 @@ def parse_arguments():
parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--validate", type=str, default="ginka-eval.json")
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--checkpoint", type=int, default=5)
parser.add_argument("--load_optim", type=bool, default=True)
args = parser.parse_args()
return args
def clip_weights(model, clip_value=0.01):
    """Clamp every parameter of *model* into ``[-clip_value, clip_value]``.

    Classic WGAN weight clipping used to enforce the critic's Lipschitz
    constraint. Mutates the parameters in place and returns ``None``.

    Args:
        model: Any ``nn.Module`` whose parameters should be clipped.
        clip_value: Symmetric clipping bound (default 0.01, the value
            from the original WGAN paper).
    """
    # In-place clamp under no_grad instead of rebinding `param.data` with a
    # freshly allocated tensor: same result, no extra allocation, and it is
    # the recommended way to mutate parameters outside autograd.
    with torch.no_grad():
        for param in model.parameters():
            param.clamp_(-clip_value, clip_value)
def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the generator once for each curriculum stage (1, 2, 3).

    Each masked input is passed to *gen* together with its stage index.
    When *detach* is true the three outputs are detached from the autograd
    graph (used while training the critic so generator gradients stop here).
    """
    stage_inputs = (masked1, masked2, masked3)
    outputs = tuple(gen(masked, stage) for stage, masked in enumerate(stage_inputs, start=1))
    if detach:
        return tuple(out.detach() for out in outputs)
    return outputs
def gen_total(gen, input, detach=False) -> torch.Tensor:
    """Run the full three-stage generation pipeline on *input*.

    Each stage's output is fed back as the next stage's input, i.e. the
    result is ``gen(gen(gen(input, 1), 2), 3)``. With *detach* the final
    map is detached from the autograd graph before being returned.
    """
    current = input
    for stage in (1, 2, 3):
        current = gen(current, stage)
    return current.detach() if detach else current
def train():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
args = parse_arguments()
# c_steps = 1 if args.resume else 5
# g_steps = 5 if args.resume else 1
c_steps = 5
g_steps = 1
# 1 代表课程学习阶段2 代表课程学习后,逐渐转为联合学习的阶段
# 3 代表课程学习后的联合遮挡学习阶段4 代表最后随机输入的联合学习阶段
train_stage = 1
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
random_ratio = 0
ginka = GinkaModel()
minamo = MinamoScoreModule()
minamo_sim = MinamoSimilarityModel()
ginka.to(device)
minamo.to(device)
minamo_sim.to(device)
dataset = GinkaWGANDataset(args.train, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
dataset_val = GinkaWGANDataset(args.validate, device)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
@ -82,87 +99,134 @@ def train():
ginka.load_state_dict(data_ginka["model_state"], strict=False)
minamo.load_state_dict(data_minamo["model_state"], strict=False)
if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None:
if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
c_steps = data_ginka["c_steps"]
g_steps = data_ginka["g_steps"]
if data_ginka.get("mask_ratio") is not None:
mask_ratio = data_ginka["mask_ratio"]
if data_ginka.get("random_ratio") is not None:
random_ratio = data_ginka["random_ratio"]
if data_ginka.get("stage") is not None:
train_stage = data_ginka["stage"]
if args.load_optim:
if data_ginka["optim_state"] is not None:
if data_ginka.get("optim_state") is not None:
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
if data_minamo["optim_state"] is not None:
if data_minamo.get("optim_state") is not None:
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
if data_minamo["optim_state_sim"] is not None:
optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"])
dataset.train_stage = train_stage
dataset.mask_ratio1 = mask_ratio
dataset.mask_ratio2 = mask_ratio
dataset.mask_ratio3 = mask_ratio
dataset.random_ratio = random_ratio
dataset_val.train_stage = train_stage
dataset_val.mask_ratio1 = mask_ratio
dataset_val.mask_ratio2 = mask_ratio
dataset_val.mask_ratio3 = mask_ratio
dataset_val.random_ratio = random_ratio
print("Train from loaded state.")
low_loss_epochs = 0
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
loss_total_minamo = torch.Tensor([0]).to(device)
loss_total_minamo_sim = torch.Tensor([0]).to(device)
loss_total_ginka = torch.Tensor([0]).to(device)
dis_total = torch.Tensor([0]).to(device)
loss_ce_total = torch.Tensor([0]).to(device)
for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
batch_size = real_data.size(0)
real_data = real_data.to(device)
real_graph = batch_convert_soft_map_to_graph(real_data)
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
# ---------- 训练判别器
for _ in range(c_steps):
# 生成假样本
optimizer_minamo.zero_grad()
z = torch.rand(batch_size, 1024, device=device)
fake_data = ginka(z)
fake_data = fake_data.detach()
# 计算判别器输出
# 反向传播
dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
loss_d.backward()
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0)
# total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
# print("Critic 梯度范数:", total_norm.item())
# print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
# print("Critic 输出范围:", d_real.min().item(), d_real.max().item())
optimizer_ginka.zero_grad()
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
dis_avg = (dis1 + dis2 + dis3) / 3.0
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
# 反向传播
loss_d_avg.backward()
elif train_stage == 3:
pass
optimizer_minamo.step()
loss_total_minamo += loss_d.detach()
dis_total += dis.detach()
loss_total_minamo += loss_d_avg.detach()
dis_total += dis_avg.detach()
# ---------- 训练生成器
for _ in range(g_steps):
optimizer_minamo.zero_grad()
optimizer_ginka.zero_grad()
# optimizer_minamo_sim.zero_grad()
z1 = torch.randn(batch_size, 1024, device=device)
z2 = torch.randn(batch_size, 1024, device=device)
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
# 先训练辅助判别器
# loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2)
# loss_c_assist.backward(retain_graph=True)
# optimizer_minamo_sim.step()
loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g
# loss_total_minamo_sim += loss_c_assist.detach()
# tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")
if train_stage == 1 or train_stage == 2:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1)
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2)
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3)
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
loss_g.backward()
optimizer_ginka.step()
loss_total_ginka += loss_g.detach()
loss_ce_total += loss_ce.detach()
elif train_stage == 3:
pass
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps
avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
avg_dis = dis_total.item() / len(dataloader) / c_steps
tqdm.write(
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\
f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\
f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\
f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
f"Epoch: {epoch + 1} | W: {avg_dis:.8f} | " +
f"G: {avg_loss_ginka:.8f} | D: {avg_loss_minamo:.8f} | " +
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
)
if avg_loss_ce < 0.5:
low_loss_epochs += 1
else:
low_loss_epochs = 0
if low_loss_epochs >= 5 and train_stage == 2:
random_ratio += 0.1
random_ratio = min(random_ratio, 0.5)
low_loss_epochs = 0
if low_loss_epochs >= 5 and train_stage == 1:
if mask_ratio >= 0.9:
train_stage = 2
mask_ratio += 0.1
mask_ratio = min(mask_ratio, 0.9)
low_loss_epochs = 0
dataset.train_stage = 2
dataset_val.train_stage = 2
dataset.random_ratio = random_ratio
dataset_val.random_ratio = random_ratio
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
# scheduler_ginka.step()
# scheduler_minamo.step()
@ -172,38 +236,44 @@ def train():
g_steps = 1
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15)
c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1))
else:
c_steps = 5
# 每若干轮输出一次图片,并保存检查点
if (epoch + 1) % args.checkpoint == 0:
# 输出 20 张图片,每批次 4 张,一共五批
idx = 0
with torch.no_grad():
for _ in range(5):
z = torch.randn(4, 1024, device=device)
output = ginka(z)
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
for matrix in map_matrix:
image = matrix_to_image_cv(matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
idx += 1
# 保存检查点
torch.save({
"model_state": ginka.state_dict(),
"optim_state": optimizer_ginka.state_dict(),
"c_steps": c_steps,
"g_steps": g_steps
"g_steps": g_steps,
"stage": train_stage,
"mask_ratio": mask_ratio,
"random_ratio": random_ratio,
}, f"result/wgan/ginka-{epoch + 1}.pth")
torch.save({
"model_state": minamo.state_dict(),
"model_state_sim": minamo_sim.state_dict(),
"optim_state": optimizer_minamo.state_dict(),
"optim_state_sim": optimizer_minamo_sim.state_dict()
"optim_state": optimizer_minamo.state_dict()
}, f"result/wgan/minamo-{epoch + 1}.pth")
idx = 0
with torch.no_grad():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
if train_stage == 1:
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
for i in range(fake1.shape[0]):
for key, one in enumerate([fake1, fake2, fake3]):
map_matrix = one[i]
image = matrix_to_image_cv(map_matrix, tile_dict)
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
idx += 1
print("Train ended.")
torch.save({
@ -211,7 +281,6 @@ def train():
}, f"result/ginka.pth")
torch.save({
"model_state": minamo.state_dict(),
"model_state_sim": minamo_sim.state_dict(),
}, f"result/minamo.pth")
if __name__ == "__main__":

View File

@ -20,23 +20,41 @@ class MinamoModel(nn.Module):
return vision_feat, topo_feat
class MinamoScoreHead(nn.Module):
    """Per-stage critic head.

    Holds two spectrally normalized linear projections: one scoring the
    vision feature, one scoring the topology feature.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Sequential wrappers kept so checkpoint state_dict keys stay stable.
        self.vision_fc = nn.Sequential(spectral_norm(nn.Linear(in_dim, out_dim)))
        self.topo_fc = nn.Sequential(spectral_norm(nn.Linear(in_dim, out_dim)))

    def forward(self, vis_feat, topo_feat):
        """Return ``(vision_score, topo_score)`` for the given feature pair."""
        return self.vision_fc(vis_feat), self.topo_fc(topo_feat)
class MinamoScoreModule(nn.Module):
    """WGAN critic for generated maps.

    Combines a graph-based topology branch and a grid-based vision branch,
    then projects each 512-dim feature through a stage-specific
    ``MinamoScoreHead`` so every curriculum stage has its own critic output.
    """

    def __init__(self, tile_types=32):
        super().__init__()
        self.topo_model = MinamoTopoModel(tile_types)
        self.vision_model = MinamoVisionModel(tile_types)
        # One scoring head per curriculum stage.
        # NOTE: the stale `topo_fc`/`vision_fc` output layers and the old
        # stage-less `forward` (left over from the pre-curriculum version,
        # where the first `forward` was silently overridden by the second)
        # have been removed; old checkpoints still load via strict=False.
        self.head1 = MinamoScoreHead(512, 1)
        self.head2 = MinamoScoreHead(512, 1)
        self.head3 = MinamoScoreHead(512, 1)

    def forward(self, map, graph, stage):
        """Score a map/graph pair with the head for *stage* (1, 2 or 3).

        Returns ``(score, vision_score, topo_score)`` where ``score`` is the
        VISION_WEIGHT/TOPO_WEIGHT weighted combination.

        Raises:
            RuntimeError: if *stage* is not 1, 2 or 3.
        """
        vision_feat = self.vision_model(map)
        topo_feat = self.topo_model(graph)
        if stage == 1:
            vision_score, topo_score = self.head1(vision_feat, topo_feat)
        elif stage == 2:
            vision_score, topo_score = self.head2(vision_feat, topo_feat)
        elif stage == 3:
            vision_score, topo_score = self.head3(vision_feat, topo_feat)
        else:
            raise RuntimeError("Unknown critic stage.")
        score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
        return score, vision_score, topo_score

4
train.sh Normal file
View File

@ -0,0 +1,4 @@
# Train from scratch
python3 -u -m ginka.train_wgan >> output.log
# Resume training from saved checkpoints
python3 -u -m ginka.train_wgan --resume true --state_ginka "result/wgan/ginka-100.pth" --state_minamo "result/wgan/minamo-100.pth" >> output.log