mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 19:31:12 +08:00
refactor: 课程训练
This commit is contained in:
parent
99f46150be
commit
f6b1ad6ebd
127
ginka/dataset.py
127
ginka/dataset.py
@ -3,10 +3,18 @@ import random
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from minamo.model.model import MinamoModel
|
import torch
|
||||||
from shared.graph import differentiable_convert_to_data
|
import torch.nn.functional as F
|
||||||
|
from typing import List
|
||||||
from shared.utils import random_smooth_onehot
|
from shared.utils import random_smooth_onehot
|
||||||
|
|
||||||
|
STAGE1_MASK = [0, 1, 10, 11]
|
||||||
|
STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12]
|
||||||
|
STAGE2_MASK = [6, 7, 8, 9]
|
||||||
|
STAGE2_REMOVE = [2, 3, 4, 5, 12]
|
||||||
|
STAGE3_MASK = [2, 3, 4, 5, 12]
|
||||||
|
STAGE3_REMOVE = []
|
||||||
|
|
||||||
def load_data(path: str):
|
def load_data(path: str):
|
||||||
with open(path, 'r', encoding="utf-8") as f:
|
with open(path, 'r', encoding="utf-8") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
@ -23,38 +31,45 @@ def load_minamo_gan_data(data: list):
|
|||||||
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
|
res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
|
||||||
return res
|
return res
|
||||||
|
|
||||||
class GinkaDataset(Dataset):
|
def apply_curriculum_mask(
|
||||||
def __init__(self, data_path: str, device, minamo: MinamoModel):
|
maps: torch.Tensor, # [B, C, H, W]
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
mask_classes: List[int], # 要遮挡的类别索引
|
||||||
self.max_size = 32
|
remove_classes: List[int], # 要移除的类别索引
|
||||||
self.minamo = minamo
|
mask_ratio: float # 遮挡比例 0~1
|
||||||
self.device = device
|
) -> torch.Tensor:
|
||||||
|
C, H, W = maps.shape
|
||||||
|
device = maps.device
|
||||||
|
masked_maps = maps.clone()
|
||||||
|
|
||||||
def __len__(self):
|
# Step 1: 移除不需要的类别(全设为 0 类)
|
||||||
return len(self.data)
|
if remove_classes:
|
||||||
|
remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
|
||||||
|
masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
|
||||||
|
masked_maps[0][remove_mask[0, :, :]] = 1 # 设置为“空地”
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
removed_maps = masked_maps.clone()
|
||||||
item = self.data[idx]
|
|
||||||
|
|
||||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
# Step 2: 对指定类别随机遮挡
|
||||||
min_main = random.uniform(0.75, 0.9)
|
for cls in mask_classes:
|
||||||
max_main = random.uniform(0.9, 1)
|
cls_mask = masked_maps[:, cls] > 0 # 目标类别的像素布尔掩码 [H, W]
|
||||||
epsilon = random.uniform(0, 0.25)
|
indices = cls_mask.nonzero(as_tuple=False) # 所有该类像素坐标
|
||||||
target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon)
|
num_mask = int(len(indices) * mask_ratio)
|
||||||
graph = differentiable_convert_to_data(target_smooth).to(self.device)
|
if num_mask > 0:
|
||||||
target = target.to(self.device)
|
selected = indices[torch.randperm(len(indices))[:num_mask]]
|
||||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
masked_maps[cls, selected[:, 0], selected[:, 1]] = 0
|
||||||
|
masked_maps[0, selected[:, 0], selected[:, 1]] = 1 # 置为“空地”
|
||||||
|
|
||||||
return {
|
return removed_maps, masked_maps
|
||||||
"target_vision_feat": vision_feat,
|
|
||||||
"target_topo_feat": topo_feat,
|
|
||||||
"target": target,
|
|
||||||
}
|
|
||||||
|
|
||||||
class GinkaWGANDataset(Dataset):
|
class GinkaWGANDataset(Dataset):
|
||||||
def __init__(self, data_path: str, device):
|
def __init__(self, data_path: str, device):
|
||||||
self.data = load_data(data_path) # 自定义数据加载函数
|
self.data = load_data(data_path) # 自定义数据加载函数
|
||||||
self.device = device
|
self.device = device
|
||||||
|
self.train_stage = 1
|
||||||
|
self.mask_ratio1 = 0.1
|
||||||
|
self.mask_ratio2 = 0.1
|
||||||
|
self.mask_ratio3 = 0.1
|
||||||
|
self.random_ratio = 0.0
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data)
|
return len(self.data)
|
||||||
@ -63,56 +78,20 @@ class GinkaWGANDataset(Dataset):
|
|||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
# min_main = random.uniform(0.8, 0.9)
|
|
||||||
# max_main = random.uniform(0.9, 1)
|
|
||||||
# epsilon = random.uniform(0, 0.2)
|
|
||||||
# target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
|
|
||||||
|
|
||||||
return target
|
if self.train_stage == 1:
|
||||||
|
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
|
||||||
|
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
|
||||||
|
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
|
||||||
|
elif self.train_stage == 2:
|
||||||
|
removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
|
||||||
|
removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9))
|
||||||
|
removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9))
|
||||||
|
|
||||||
class MinamoGANDataset(Dataset):
|
if self.random_ratio > 0:
|
||||||
def __init__(self, refer_data_path):
|
removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
||||||
self.refer = load_minamo_gan_data(load_data(refer_data_path))
|
removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
||||||
self.data = list()
|
removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
|
||||||
self.data.extend(random.sample(self.refer, 1000))
|
|
||||||
|
|
||||||
def set_data(self, data: list):
|
return removed1, masked1, removed2, masked2, removed3, masked3
|
||||||
self.data.clear()
|
|
||||||
self.data.extend(data)
|
|
||||||
k = min(len(data) / 4, len(self.refer))
|
|
||||||
self.data.extend(random.sample(self.refer, int(k)))
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.data)
|
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
|
||||||
# 假定 map2 是参考地图
|
|
||||||
item = self.data[idx]
|
|
||||||
|
|
||||||
map1, map2, vis_sim, topo_sim, review = item
|
|
||||||
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
|
|
||||||
if review:
|
|
||||||
map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
|
||||||
else:
|
|
||||||
map1 = torch.FloatTensor(map1)
|
|
||||||
map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
|
||||||
|
|
||||||
min_main = random.uniform(0.75, 0.9)
|
|
||||||
max_main = random.uniform(0.9, 1)
|
|
||||||
epsilon = random.uniform(0, 0.25)
|
|
||||||
|
|
||||||
if review:
|
|
||||||
map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
|
|
||||||
map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)
|
|
||||||
|
|
||||||
graph1 = differentiable_convert_to_data(map1)
|
|
||||||
graph2 = differentiable_convert_to_data(map2)
|
|
||||||
|
|
||||||
return (
|
|
||||||
map1,
|
|
||||||
map2,
|
|
||||||
torch.FloatTensor([vis_sim]),
|
|
||||||
torch.FloatTensor([topo_sim]),
|
|
||||||
graph1,
|
|
||||||
graph2
|
|
||||||
)
|
|
||||||
@ -2,43 +2,24 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class GinkaInput(nn.Module):
|
class GinkaInput(nn.Module):
|
||||||
def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)):
|
def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
self.out_size = out_size
|
||||||
self.fc = nn.Sequential(
|
self.fc = nn.Sequential(
|
||||||
nn.Linear(feat_dim, size[0] * size[1] * out_ch),
|
nn.Linear(in_size[0] * in_size[1], out_size[0] * out_size[1]),
|
||||||
nn.Unflatten(1, (out_ch, *size))
|
nn.LayerNorm(out_size[0] * out_size[1]),
|
||||||
|
nn.ELU()
|
||||||
|
)
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
||||||
|
nn.InstanceNorm2d(out_ch),
|
||||||
|
nn.ELU()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
B, C, H, W = x.shape
|
||||||
|
x = x.view(B, C, H * W)
|
||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
|
x = x.view(B, C, self.out_size[0], self.out_size[1])
|
||||||
|
x = self.conv(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class FeatureEncoder(nn.Module):
|
|
||||||
def __init__(self, feat_dim, size, mid_ch, out_ch):
|
|
||||||
super().__init__()
|
|
||||||
self.encode = nn.Sequential(
|
|
||||||
nn.Linear(feat_dim, mid_ch * size * size),
|
|
||||||
nn.Unflatten(1, (mid_ch, size, size)),
|
|
||||||
nn.Conv2d(mid_ch, out_ch, 1)
|
|
||||||
)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x = self.encode(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
class GinkaFeatureInput(nn.Module):
|
|
||||||
def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64):
|
|
||||||
super().__init__()
|
|
||||||
self.encode1 = FeatureEncoder(feat_dim, 32, mid_ch, out_ch)
|
|
||||||
self.encode2 = FeatureEncoder(feat_dim, 16, mid_ch * 2, out_ch * 2)
|
|
||||||
self.encode3 = FeatureEncoder(feat_dim, 8, mid_ch * 4, out_ch * 4)
|
|
||||||
self.encode4 = FeatureEncoder(feat_dim, 4, mid_ch * 8, out_ch * 8)
|
|
||||||
self.encode5 = FeatureEncoder(feat_dim, 2, mid_ch * 16, out_ch * 16)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
x1 = self.encode1(x)
|
|
||||||
x2 = self.encode2(x)
|
|
||||||
x3 = self.encode3(x)
|
|
||||||
x4 = self.encode4(x)
|
|
||||||
x5 = self.encode5(x)
|
|
||||||
return x1, x2, x3, x4, x5
|
|
||||||
|
|||||||
@ -13,6 +13,13 @@ from shared.similarity.vision import calculate_visual_similarity
|
|||||||
CLASS_NUM = 32
|
CLASS_NUM = 32
|
||||||
ILLEGAL_MAX_NUM = 12
|
ILLEGAL_MAX_NUM = 12
|
||||||
|
|
||||||
|
STAGE_ALLOWED = [
|
||||||
|
[],
|
||||||
|
[0, 1, 10, 11],
|
||||||
|
[6, 7, 8, 9,],
|
||||||
|
[2, 3, 4, 5, 12]
|
||||||
|
]
|
||||||
|
|
||||||
def get_not_allowed(classes: list[int], include_illegal=False):
|
def get_not_allowed(classes: list[int], include_illegal=False):
|
||||||
res = list()
|
res = list()
|
||||||
for num in range(0, CLASS_NUM):
|
for num in range(0, CLASS_NUM):
|
||||||
@ -301,24 +308,47 @@ def js_divergence(p, q, eps=1e-8):
|
|||||||
|
|
||||||
return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
|
return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)
|
||||||
|
|
||||||
|
def immutable_penalty_loss(
|
||||||
|
pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
惩罚模型修改不可更改区域的损失。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input: 模型输出 [B, C, H, W],概率分布 (softmax 后)
|
||||||
|
target: 原始输入图 [B, C, H, W],概率分布 (softmax 后)
|
||||||
|
modifiable_classes: 允许被修改的类别列表
|
||||||
|
penalty_weight: 对非允许修改区域的惩罚系数
|
||||||
|
"""
|
||||||
|
not_allowed = get_not_allowed(modifiable_classes, include_illegal=True)
|
||||||
|
input_mask = pred[:, not_allowed, :, :]
|
||||||
|
with torch.no_grad():
|
||||||
|
target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1)
|
||||||
|
target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()
|
||||||
|
|
||||||
|
# 差异区域(模型试图改变的地方)
|
||||||
|
penalty = F.cross_entropy(input_mask, target_mask)
|
||||||
|
|
||||||
|
return penalty
|
||||||
|
|
||||||
class WGANGinkaLoss:
|
class WGANGinkaLoss:
|
||||||
def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4):
|
def __init__(self, lambda_gp=100, weight=[1, 0.4, 10, 0.2, 0.2]):
|
||||||
|
# weight: 判别器损失,L1 损失,不可修改类型损失
|
||||||
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
self.lambda_gp = lambda_gp # 梯度惩罚系数
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.diversity_lamda = diversity_lamda
|
|
||||||
|
|
||||||
def compute_gradient_penalty(self, critic, real_data, fake_data):
|
def compute_gradient_penalty(self, critic, stage, real_data, fake_data):
|
||||||
# 进行插值
|
# 进行插值
|
||||||
batch_size = real_data.size(0)
|
batch_size = real_data.size(0)
|
||||||
epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
|
epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
|
||||||
interp_data = interpolate_data(real_data, fake_data, epsilon_data)
|
interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)
|
||||||
interp_graph = batch_convert_soft_map_to_graph(interp_data)
|
interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device)
|
||||||
|
|
||||||
# 对图像进行反向传播并计算梯度
|
# 对图像进行反向传播并计算梯度
|
||||||
interp_data.requires_grad_()
|
interp_data.requires_grad_()
|
||||||
interp_graph.x.requires_grad_()
|
interp_graph.x.requires_grad_()
|
||||||
|
|
||||||
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph)
|
_, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage)
|
||||||
|
|
||||||
# 计算梯度
|
# 计算梯度
|
||||||
grad_vis = torch.autograd.grad(
|
grad_vis = torch.autograd.grad(
|
||||||
@ -344,21 +374,21 @@ class WGANGinkaLoss:
|
|||||||
return gp_loss
|
return gp_loss
|
||||||
|
|
||||||
def discriminator_loss(
|
def discriminator_loss(
|
||||||
self, critic, real_data: torch.Tensor,
|
self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
|
||||||
real_graph: torch.Tensor, fake_data: torch.Tensor
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
):
|
|
||||||
""" 判别器损失函数 """
|
""" 判别器损失函数 """
|
||||||
|
real_graph = batch_convert_soft_map_to_graph(real_data)
|
||||||
fake_graph = batch_convert_soft_map_to_graph(fake_data)
|
fake_graph = batch_convert_soft_map_to_graph(fake_data)
|
||||||
real_scores, _, _ = critic(real_data, real_graph)
|
real_scores, _, _ = critic(real_data, real_graph, stage)
|
||||||
fake_scores, _, _ = critic(fake_data, fake_graph)
|
fake_scores, _, _ = critic(fake_data, fake_graph, stage)
|
||||||
|
|
||||||
# print("Critic 输出范围", fake_scores.min().item(), fake_scores.max().item(), real_scores.min().item(), real_scores.max().item())
|
|
||||||
|
|
||||||
# Wasserstein 距离
|
# Wasserstein 距离
|
||||||
d_loss = fake_scores.mean() - real_scores.mean()
|
d_loss = fake_scores.mean() - real_scores.mean()
|
||||||
grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data)
|
grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data)
|
||||||
|
|
||||||
return d_loss, d_loss + self.lambda_gp * grad_loss
|
total_loss = d_loss + self.lambda_gp * grad_loss
|
||||||
|
|
||||||
|
return total_loss, d_loss
|
||||||
|
|
||||||
def calculate_similarity_one(self, map1, map2):
|
def calculate_similarity_one(self, map1, map2):
|
||||||
topo1 = build_topological_graph(map1)
|
topo1 = build_topological_graph(map1)
|
||||||
@ -369,72 +399,28 @@ class WGANGinkaLoss:
|
|||||||
|
|
||||||
return vis_sim, topo_sim
|
return vis_sim, topo_sim
|
||||||
|
|
||||||
def discriminator_loss_assist(self, critic, fake_data1, fake_data2):
|
def generator_loss(self, critic, stage, mask_ratio, real, fake, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
graph1 = batch_convert_soft_map_to_graph(fake_data1)
|
""" 生成器损失函数 """
|
||||||
graph2 = batch_convert_soft_map_to_graph(fake_data2)
|
fake_graph = batch_convert_soft_map_to_graph(fake)
|
||||||
vis_feat_1, topo_feat_1 = critic(fake_data1, graph1)
|
|
||||||
vis_feat_2, topo_feat_2 = critic(fake_data2, graph2)
|
|
||||||
|
|
||||||
batch1 = torch.argmax(fake_data1, dim=1).cpu().tolist()
|
fake_scores, _, _ = critic(fake, fake_graph, stage)
|
||||||
batch2 = torch.argmax(fake_data2, dim=1).cpu().tolist()
|
|
||||||
|
|
||||||
vis_sim_real = []
|
|
||||||
topo_sim_real = []
|
|
||||||
|
|
||||||
for i in range(len(batch1)):
|
|
||||||
vis_sim, topo_sim = self.calculate_similarity_one(batch1[i], batch2[i])
|
|
||||||
vis_sim_real.append(vis_sim)
|
|
||||||
topo_sim_real.append(topo_sim)
|
|
||||||
|
|
||||||
vis_sim_real = torch.Tensor(vis_sim_real)
|
|
||||||
topo_sim_real = torch.Tensor(topo_sim_real)
|
|
||||||
|
|
||||||
pred_vis_sim = F.cosine_similarity(vis_feat_1, vis_feat_2).cpu()
|
|
||||||
pred_topo_sim = F.cosine_similarity(topo_feat_1, topo_feat_2).cpu()
|
|
||||||
|
|
||||||
loss1 = F.l1_loss(pred_vis_sim, vis_sim_real) * VISION_WEIGHT + F.l1_loss(pred_topo_sim, topo_sim_real) * TOPO_WEIGHT
|
|
||||||
|
|
||||||
return loss1
|
|
||||||
|
|
||||||
def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2):
|
|
||||||
loss1 = self.discriminator_loss_assist(critic, real_data, fake_data1)
|
|
||||||
loss2 = self.discriminator_loss_assist(critic, real_data, fake_data2)
|
|
||||||
loss3 = self.discriminator_loss_assist(critic, fake_data1, fake_data2)
|
|
||||||
|
|
||||||
return loss1 / 3.0 + loss2 / 3.0 + loss3 / 3.0
|
|
||||||
|
|
||||||
def generator_loss_one(self, critic, fake, fake_graph):
|
|
||||||
fake_scores, _, _ = critic(fake, fake_graph)
|
|
||||||
minamo_loss = -torch.mean(fake_scores)
|
minamo_loss = -torch.mean(fake_scores)
|
||||||
class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
ce_loss = F.cross_entropy(fake, real)
|
||||||
entrance_loss = entrance_constraint_loss(fake)
|
immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
|
||||||
|
constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
|
||||||
|
|
||||||
losses = [
|
losses = [
|
||||||
minamo_loss * self.weight[0],
|
minamo_loss * self.weight[0],
|
||||||
class_loss * self.weight[1],
|
ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio), # 蒙版越大,交叉熵损失权重越小
|
||||||
entrance_loss * self.weight[2]
|
immutable_loss * self.weight[2],
|
||||||
|
constraint_loss * self.weight[3]
|
||||||
]
|
]
|
||||||
|
|
||||||
return sum(losses)
|
if stage == 1:
|
||||||
|
# 第一个阶段检查入口存在性
|
||||||
|
entrance_loss = entrance_constraint_loss(fake)
|
||||||
|
losses.append(entrance_loss * self.weight[4])
|
||||||
|
|
||||||
def generator_loss(self, critic, critic_assist, fake1, fake2):
|
# print(losses[2].item())
|
||||||
""" 生成器损失函数 """
|
|
||||||
fake_graph1 = batch_convert_soft_map_to_graph(fake1)
|
|
||||||
fake_graph2 = batch_convert_soft_map_to_graph(fake2)
|
|
||||||
|
|
||||||
loss1 = self.generator_loss_one(critic, fake1, fake_graph1)
|
return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss
|
||||||
loss2 = self.generator_loss_one(critic, fake2, fake_graph2)
|
|
||||||
|
|
||||||
# vis_feat1, topo_feat1 = critic_assist(fake1, fake_graph1)
|
|
||||||
# vis_feat2, topo_feat2 = critic_assist(fake2, fake_graph2)
|
|
||||||
|
|
||||||
# vis_sim = F.cosine_similarity(vis_feat1, vis_feat2)
|
|
||||||
# topo_sim = F.cosine_similarity(topo_feat1, topo_feat2)
|
|
||||||
# similarity = vis_sim * VISION_WEIGHT + topo_sim * TOPO_WEIGHT
|
|
||||||
|
|
||||||
# print(similarity.mean().item())
|
|
||||||
# div_loss = F.l1_loss(fake1[:, :, 1:-1, 1:-1], fake2[:, :, 1:-1, 1:-1])
|
|
||||||
|
|
||||||
return loss1 * 0.5 + loss2 * 0.5\
|
|
||||||
# + self.diversity_lamda * F.relu(0.7 - div_loss).mean()
|
|
||||||
# + self.diversity_lamda * F.relu(similarity - 0.4).mean()
|
|
||||||
|
|||||||
@ -9,27 +9,29 @@ def print_memory(tag=""):
|
|||||||
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
|
||||||
|
|
||||||
class GinkaModel(nn.Module):
|
class GinkaModel(nn.Module):
|
||||||
def __init__(self, feat_dim=1024, base_ch=64, out_ch=32):
|
def __init__(self, base_ch=64, out_ch=32):
|
||||||
"""Ginka Model 模型定义部分
|
"""Ginka Model 模型定义部分
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.unet = GinkaUNet(base_ch, base_ch, feat_dim)
|
self.input = GinkaInput(32, 32, (13, 13), (32, 32))
|
||||||
|
self.unet = GinkaUNet(32, base_ch, base_ch)
|
||||||
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, stage):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
x: 参考地图的特征向量
|
x: 参考地图的特征向量
|
||||||
Returns:
|
Returns:
|
||||||
logits: 输出logits [BS, num_classes, H, W]
|
logits: 输出logits [BS, num_classes, H, W]
|
||||||
"""
|
"""
|
||||||
|
x = self.input(x)
|
||||||
x = self.unet(x)
|
x = self.unet(x)
|
||||||
x = self.output(x)
|
x = self.output(x, stage)
|
||||||
return F.softmax(x, dim=1)
|
return F.softmax(x, dim=1)
|
||||||
|
|
||||||
# 检查显存占用
|
# 检查显存占用
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
feat = torch.randn((1, 1024)).cuda()
|
input = torch.randn((1, 32, 13, 13)).cuda()
|
||||||
|
|
||||||
# 初始化模型
|
# 初始化模型
|
||||||
model = GinkaModel().cuda()
|
model = GinkaModel().cuda()
|
||||||
@ -37,14 +39,13 @@ if __name__ == "__main__":
|
|||||||
print_memory("初始化后")
|
print_memory("初始化后")
|
||||||
|
|
||||||
# 前向传播
|
# 前向传播
|
||||||
output = model(feat)
|
output = model(input, 1)
|
||||||
|
|
||||||
print_memory("前向传播后")
|
print_memory("前向传播后")
|
||||||
|
|
||||||
print(f"输入形状: feat={feat.shape}")
|
print(f"输入形状: feat={input.shape}")
|
||||||
print(f"输出形状: output={output.shape}")
|
print(f"输出形状: output={output.shape}")
|
||||||
# print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
|
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
|
||||||
# print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
|
|
||||||
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
|
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
|
||||||
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
|
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
|
||||||
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
|
||||||
|
|||||||
@ -1,13 +1,42 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class GinkaOutput(nn.Module):
|
class StageHead(nn.Module):
|
||||||
def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
|
def __init__(self, in_ch, out_ch, out_size=(13, 13)):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.conv_down = nn.Sequential(
|
self.head = nn.Sequential(
|
||||||
|
nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
|
||||||
|
nn.InstanceNorm2d(in_ch),
|
||||||
|
nn.ELU(),
|
||||||
|
|
||||||
|
nn.Conv2d(in_ch, in_ch, 1),
|
||||||
|
nn.InstanceNorm2d(in_ch),
|
||||||
|
nn.ELU(),
|
||||||
|
)
|
||||||
|
self.pool = nn.Sequential(
|
||||||
nn.AdaptiveMaxPool2d(out_size),
|
nn.AdaptiveMaxPool2d(out_size),
|
||||||
nn.Conv2d(in_ch, out_ch, 1)
|
nn.Conv2d(in_ch, out_ch, 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
return self.conv_down(x)
|
x = self.head(x)
|
||||||
|
x = self.pool(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
class GinkaOutput(nn.Module):
|
||||||
|
def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
|
||||||
|
super().__init__()
|
||||||
|
self.head1 = StageHead(in_ch, out_ch, out_size)
|
||||||
|
self.head2 = StageHead(in_ch, out_ch, out_size)
|
||||||
|
self.head3 = StageHead(in_ch, out_ch, out_size)
|
||||||
|
|
||||||
|
def forward(self, x, stage):
|
||||||
|
if stage == 1:
|
||||||
|
x = self.head1(x)
|
||||||
|
elif stage == 2:
|
||||||
|
x = self.head2(x)
|
||||||
|
elif stage == 3:
|
||||||
|
x = self.head3(x)
|
||||||
|
else:
|
||||||
|
raise RuntimeError("Unknown generate stage.")
|
||||||
|
return x
|
||||||
|
|||||||
@ -198,15 +198,15 @@ class GinkaBottleneck(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
class GinkaUNet(nn.Module):
|
class GinkaUNet(nn.Module):
|
||||||
def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
|
def __init__(self, in_ch=32, base_ch=64, out_ch=32):
|
||||||
"""Ginka Model UNet 部分
|
"""Ginka Model UNet 部分
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input = GinkaTransformerEncoder(
|
# self.input = GinkaTransformerEncoder(
|
||||||
in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
|
# in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
|
||||||
token_size=4, ff_dim=feat_dim*2, num_layers=4
|
# token_size=4, ff_dim=feat_dim*2, num_layers=4
|
||||||
)
|
# )
|
||||||
self.down1 = ConvBlock(2, base_ch)
|
self.down1 = ConvBlock(in_ch, base_ch)
|
||||||
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
|
self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
|
||||||
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
|
self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
|
||||||
self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
|
self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
|
||||||
@ -223,10 +223,6 @@ class GinkaUNet(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
B, D = x.shape # [B, 1024]
|
|
||||||
x = x.view(B, 4, D // 4) # [B, 4, 256]
|
|
||||||
x = self.input(x) # [B, 4, 512]
|
|
||||||
x = x.view(B, 2, 32, 32) # [B, 2, 32, 32]
|
|
||||||
x1 = self.down1(x) # [B, 64, 32, 32]
|
x1 = self.down1(x) # [B, 64, 32, 32]
|
||||||
x2 = self.down2(x1) # [B, 128, 16, 16]
|
x2 = self.down2(x1) # [B, 128, 16, 16]
|
||||||
x3 = self.down3(x2) # [B, 256, 8, 8]
|
x3 = self.down3(x2) # [B, 256, 8, 8]
|
||||||
@ -237,5 +233,6 @@ class GinkaUNet(nn.Module):
|
|||||||
x = self.up1(x4, x3) # [B, 256, 8, 8]
|
x = self.up1(x4, x3) # [B, 256, 8, 8]
|
||||||
x = self.up2(x, x2) # [B, 128, 16, 16]
|
x = self.up2(x, x2) # [B, 128, 16, 16]
|
||||||
x = self.up3(x, x1) # [B, 64, 32, 32]
|
x = self.up3(x, x1) # [B, 64, 32, 32]
|
||||||
|
x = self.final(x) # [B, 32, 32, 32]
|
||||||
|
|
||||||
return self.final(x) # [B, 32, 32, 32]
|
return x
|
||||||
|
|||||||
134
ginka/train.py
134
ginka/train.py
@ -1,134 +0,0 @@
|
|||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
import torch
|
|
||||||
import torch.optim as optim
|
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
from .model.model import GinkaModel
|
|
||||||
from .model.loss import GinkaLoss
|
|
||||||
from .dataset import GinkaDataset
|
|
||||||
from minamo.model.model import MinamoModel
|
|
||||||
from shared.args import parse_arguments
|
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
os.makedirs("result", exist_ok=True)
|
|
||||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
|
||||||
|
|
||||||
def train():
|
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
|
||||||
|
|
||||||
args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')
|
|
||||||
|
|
||||||
model = GinkaModel()
|
|
||||||
model.to(device)
|
|
||||||
minamo = MinamoModel(32)
|
|
||||||
minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
|
|
||||||
minamo.to(device)
|
|
||||||
minamo.eval()
|
|
||||||
|
|
||||||
# 准备数据集
|
|
||||||
dataset = GinkaDataset(args.train, device, minamo)
|
|
||||||
dataset_val = GinkaDataset(args.validate, device, minamo)
|
|
||||||
dataloader = DataLoader(
|
|
||||||
dataset,
|
|
||||||
batch_size=BATCH_SIZE,
|
|
||||||
shuffle=True
|
|
||||||
)
|
|
||||||
dataloader_val = DataLoader(
|
|
||||||
dataset_val,
|
|
||||||
batch_size=BATCH_SIZE,
|
|
||||||
shuffle=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# 设定优化器与调度器
|
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
|
||||||
criterion = GinkaLoss(minamo)
|
|
||||||
|
|
||||||
if args.resume:
|
|
||||||
data = torch.load(args.from_state, map_location=device)
|
|
||||||
model.load_state_dict(data["model_state"], strict=False)
|
|
||||||
if args.load_optim:
|
|
||||||
optimizer.load_state_dict(data["optimizer_state"])
|
|
||||||
print("Train from loaded state.")
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
|
||||||
criterion.weight[0] = 0.0
|
|
||||||
|
|
||||||
# 开始训练
|
|
||||||
for epoch in tqdm(range(args.epochs)):
|
|
||||||
model.train()
|
|
||||||
total_loss = 0
|
|
||||||
|
|
||||||
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
|
|
||||||
if not args.resume and epoch == 10:
|
|
||||||
criterion.weight[0] = 0.5
|
|
||||||
|
|
||||||
for batch in dataloader:
|
|
||||||
# 数据迁移到设备
|
|
||||||
target = batch["target"].to(device)
|
|
||||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
|
||||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
|
||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
|
||||||
# 前向传播
|
|
||||||
optimizer.zero_grad()
|
|
||||||
_, output_softmax = model(feat_vec)
|
|
||||||
|
|
||||||
# 计算损失
|
|
||||||
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
|
||||||
|
|
||||||
# 反向传播
|
|
||||||
losses.backward()
|
|
||||||
optimizer.step()
|
|
||||||
total_loss += losses.item()
|
|
||||||
|
|
||||||
avg_loss = total_loss / len(dataloader)
|
|
||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")
|
|
||||||
|
|
||||||
# for name, param in model.named_parameters():
|
|
||||||
# if param.grad is not None:
|
|
||||||
# print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")
|
|
||||||
|
|
||||||
# 学习率调整
|
|
||||||
scheduler.step()
|
|
||||||
|
|
||||||
if (epoch + 1) % 5 == 0:
|
|
||||||
loss_val = 0
|
|
||||||
model.eval()
|
|
||||||
with torch.no_grad():
|
|
||||||
for batch in dataloader_val:
|
|
||||||
# 数据迁移到设备
|
|
||||||
target = batch["target"].to(device)
|
|
||||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
|
||||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
|
||||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
|
||||||
|
|
||||||
# 前向传播
|
|
||||||
output, output_softmax = model(feat_vec)
|
|
||||||
print(torch.argmax(output, dim=1)[0])
|
|
||||||
|
|
||||||
# 计算损失
|
|
||||||
losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
|
|
||||||
loss_val += losses.item()
|
|
||||||
|
|
||||||
avg_val_loss = loss_val / len(dataloader_val)
|
|
||||||
tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
|
|
||||||
torch.save({
|
|
||||||
"model_state": model.state_dict(),
|
|
||||||
# "optimizer_state": optimizer.state_dict(),
|
|
||||||
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
|
||||||
|
|
||||||
|
|
||||||
print("Train ended.")
|
|
||||||
|
|
||||||
torch.save({
|
|
||||||
"model_state": model.state_dict(),
|
|
||||||
# "optimizer_state": optimizer.state_dict(),
|
|
||||||
}, f"result/ginka.pth")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
torch.set_num_threads(4)
|
|
||||||
train()
|
|
||||||
@ -1,410 +0,0 @@
|
|||||||
import argparse
|
|
||||||
import socket
|
|
||||||
import struct
|
|
||||||
import os
|
|
||||||
from datetime import datetime
|
|
||||||
import torch
|
|
||||||
import torch.optim as optim
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch_geometric.loader import DataLoader
|
|
||||||
from tqdm import tqdm
|
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
from .model.model import GinkaModel
|
|
||||||
from .model.loss import GinkaLoss, WGANGinkaLoss
|
|
||||||
from .dataset import GinkaDataset, MinamoGANDataset
|
|
||||||
from minamo.model.model import MinamoModel
|
|
||||||
from minamo.model.loss import MinamoLoss
|
|
||||||
from shared.image import matrix_to_image_cv
|
|
||||||
|
|
||||||
# ---- Training configuration -------------------------------------------------
BATCH_SIZE = 32                       # generator batch size; the discriminator uses half of it
EPOCHS_GINKA = 5                      # generator epochs per cycle
EPOCHS_MINAMO = 2                     # discriminator epochs per cycle
SOCKET_PATH = "./tmp/ginka_uds"       # Unix domain socket shared with the node process
LOSS_PATH = "result/gan/a-loss.txt"   # append-only text log, one line per cycle
REPLAY_PATH = "datasets/replay.bin"   # experience-replay store (4-byte count header + float32 records)
VISION_ALPHA = 0                      # weight of the vision score; topo gets (1 - VISION_ALPHA)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create every directory this script writes into. train() also saves images to
# result/ginka_img and checkpoints to result/minamo_checkpoint, and the replay
# file lives in datasets/, so those are created here as well.
for _out_dir in (
    "result",
    "result/ginka_checkpoint",
    "result/ginka_img",
    "result/minamo_checkpoint",
    "result/gan",
    "tmp",
    os.path.dirname(REPLAY_PATH) or ".",
):
    os.makedirs(_out_dir, exist_ok=True)

# Mark the start of a new run in the loss log.
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
    f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")

# An empty replay store is just the big-endian int32 record count, set to zero.
if not os.path.exists(REPLAY_PATH):
    with open(REPLAY_PATH, 'wb') as f:
        f.write(b'\x00\x00\x00\x00')
|
|
||||||
|
|
||||||
def _parse_bool_flag(value):
    """Convert a command-line string such as "true" / "False" / "0" to a bool."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_arguments():
    """Parse the command-line options of the training script.

    Options:
        --resume      whether to continue from a saved generator state. Accepts
                      a bare ``--resume`` as well as ``--resume true/false``.
        --from_state  path of the generator checkpoint used when resuming.
        --train       path of the training dataset.
        --validate    path of the validation dataset.
        --from_cycle  first cycle index (inclusive).
        --to_cycle    last cycle index (exclusive).

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # ``type=bool`` would turn every non-empty string (including "False") into
    # True, so the flag is parsed explicitly instead.
    parser.add_argument("--resume", type=_parse_bool_flag, nargs="?", const=True, default=False)
    parser.add_argument("--from_state", type=str, default="result/ginka.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default='ginka-eval.json')
    parser.add_argument("--from_cycle", type=int, default=0)
    parser.add_argument("--to_cycle", type=int, default=100)
    args = parser.parse_args()
    return args
|
|
||||||
|
|
||||||
def parse_ginka_batch(batch):
    """Move one generator batch to ``device`` and build the model input.

    ``batch`` is indexed with the keys ``"target"``, ``"target_vision_feat"``
    and ``"target_topo_feat"``. Dimension 1 of both feature tensors is squeezed
    (presumably a singleton axis added by the dataset; confirm against
    ``GinkaDataset``) before they are concatenated along dimension 1.

    Returns:
        ``(target, vision_feat, topo_feat, feat_vec)``, all on ``device``.
    """
    target_map = batch["target"].to(device)
    vision_feat, topo_feat = (
        batch[key].to(device).squeeze(1)
        for key in ("target_vision_feat", "target_topo_feat")
    )
    combined = torch.cat((vision_feat, topo_feat), dim=1).to(device)
    return target_map, vision_feat, topo_feat, combined
|
|
||||||
|
|
||||||
def parse_minamo_batch(batch):
    """Unpack one discriminator batch and move every element to ``device``.

    ``batch`` must unpack into exactly six items, in this order:
    ``map1, map2, vision_simi, topo_simi, graph1, graph2``.

    Returns:
        The same six items, in the same order, after ``.to(device)``.
    """
    map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
    ordered = (map1, map2, vision_simi, topo_simi, graph1, graph2)
    return tuple(item.to(device) for item in ordered)
|
|
||||||
|
|
||||||
def send_all(sock, data):
    """Keep calling ``sock.send`` until every byte of ``data`` has been sent.

    Raises:
        RuntimeError: if ``sock.send`` reports that zero bytes were written.
    """
    size = len(data)
    cursor = 0
    while cursor < size:
        written = sock.send(data[cursor:])
        if written == 0:
            raise RuntimeError("Socket connection broken")
        cursor += written
|
|
||||||
|
|
||||||
def recv_all(sock: socket.socket, length: int):
    """Receive exactly ``length`` bytes from ``sock``.

    ``sock.recv`` may return fewer bytes than requested, so it is called in a
    loop, each time asking only for what is still missing.

    Returns:
        A ``bytes`` object of exactly ``length`` bytes (empty if ``length <= 0``).

    Raises:
        ConnectionError: if the peer closes the connection before ``length``
            bytes have arrived.
    """
    # A bytearray grows in place; repeated ``bytes +=`` would re-copy the whole
    # buffer for every packet.
    buffer = bytearray()
    while len(buffer) < length:
        packet = sock.recv(length - len(buffer))  # request only the remainder
        if not packet:
            raise ConnectionError("连接中断")
        buffer += packet
    return bytes(buffer)
|
|
||||||
|
|
||||||
def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
    """Read one reply of the node process and pair it with the sent maps.

    Wire format produced by the node side (sizes in bytes, big-endian):

    * 2 - tensor count ``tc``; 2 - review count ``rc`` (review data follows
      the training data);
    * 1 * tc - number of compare maps for each delivered map tensor;
    * 2 * 4 * (N + rc) - float32 similarities, interleaved as
      ``vision, topo, vision, topo, ...`` where ``N`` is the sum of the counts;
    * N * 1 * H * W - one int8 compare map per comparison;
    * rc * 2 * H * W - review map tensors.

    Args:
        sock: connected socket to read from.
        maps: array whose last two dimensions give ``H`` and ``W``; ``maps[idx]``
            is stored as the first element of every result for tensor ``idx``.

    Returns:
        A list of ``(map1, map2, vision_similarity, topo_similarity, is_review)``
        tuples for the current round; ``is_review`` is always ``False`` here.
    """
    _, _, H, W = maps.shape

    # Every read goes through recv_all: a bare sock.recv(2) may legally return
    # a single byte and desynchronise the whole stream.
    tc, rc = struct.unpack('>hh', recv_all(sock, 4))
    count = struct.unpack(f">{tc}b", recv_all(sock, 1 * tc))
    N = sum(count)
    sim_buf = recv_all(sock, 2 * 4 * (N + rc))
    com_buf = recv_all(sock, N * 1 * H * W)
    if rc > 0:
        # The review maps are not used by this function, but they must still be
        # drained so the next message starts at the right offset. (The former
        # unpack of this buffer used a format without a type code and raised
        # struct.error whenever rc > 0.)
        recv_all(sock, rc * 2 * H * W)

    sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)
    # Decode all compare maps in one step; copy() makes the result writable.
    com = np.frombuffer(com_buf, dtype=np.int8).reshape(N, H, W).copy()

    res = list()
    flatten_idx = 0
    # Walk the current generator round: tensor idx owns count[idx] comparisons.
    for idx, com_count in enumerate(count):
        for _ in range(com_count):
            vis_sim = sim[flatten_idx * 2]
            topo_sim = sim[flatten_idx * 2 + 1]
            # map1, map2, vision_similarity, topo_similarity, is_review
            res.append((maps[idx], com[flatten_idx], vis_sim, topo_sim, False))
            flatten_idx += 1

    return res
|
|
||||||
|
|
||||||
def _exchange_replay(conn, data, prob_list, height, width):
    """Mix stored generator outputs into ``data`` and store new ones.

    The replay file starts with a big-endian int32 record count followed by
    fixed-size records, each a float32 array of shape (32, 13, 13).
    """
    record_bytes = 32 * 13 * 13 * 4  # 4 bytes per float32 value
    total = prob_list.shape[0]

    with open(REPLAY_PATH, 'r+b') as f:
        # The record count is stored at the very beginning of the file.
        f.seek(0)
        count = struct.unpack('>i', f.read(4))[0]

        if count > 0:
            replay = np.random.choice(count, size=min(count, len(data) // 4), replace=False)
            replay_data = np.empty((len(replay), 32, 13, 13))
            for i, n in enumerate(replay):
                # Offsets are in bytes: skip the header, then n whole records.
                f.seek(4 + int(n) * record_bytes)
                replay_data[i] = np.frombuffer(f.read(record_bytes), dtype=np.float32).reshape(32, 13, 13)

            map_data: np.ndarray = replay_data.argmax(axis=1)
            buf = bytearray(struct.pack('>hbb', len(replay), height, width))  # count, H, W
            buf.extend(map_data.astype(np.int8).tobytes())                    # map tensor
            conn.sendall(buf)
            data.extend(parse_minamo_data(conn, replay_data))

        # Append up to 100 randomly chosen new probability maps to the file.
        to_write = np.random.choice(total, size=min(total, 100), replace=False)
        f.seek(0, 2)  # end of file
        f.write(prob_list[to_write].tobytes())

        f.seek(0)     # back to the header to update the count
        f.write(struct.pack('>i', count + len(to_write)))
        f.flush()


def train():
    """Alternate generator (Ginka) and discriminator (Minamo) training.

    Each cycle:
      1. trains the generator for ``EPOCHS_GINKA`` epochs and validates it;
      2. generates maps for the whole training set and sends them to the node
         process over a Unix domain socket, which replies with compare maps
         and similarity scores (see ``parse_minamo_data``);
      3. mixes in replayed samples from ``REPLAY_PATH``;
      4. trains the discriminator for ``EPOCHS_MINAMO`` epochs and validates it;
      5. saves checkpoints and appends one line to ``LOSS_PATH``.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    ginka = GinkaModel()
    ginka.to(device)
    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)
    minamo.eval()

    # Datasets and loaders
    ginka_dataset = GinkaDataset(args.train, device, minamo)
    ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
    minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json")
    minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
    ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
    ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
    minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE // 2, shuffle=True)
    minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True)

    # Optimizers, schedulers and losses
    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
    criterion_ginka = GinkaLoss(minamo)

    optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
    scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=EPOCHS_MINAMO, T_mult=2, eta_min=1e-6)
    criterion_minamo = MinamoLoss()

    criterion = WGANGinkaLoss()

    # Tile images used to render generated maps
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    # Communication with the node side
    if os.path.exists(SOCKET_PATH):
        os.remove(SOCKET_PATH)
    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    conn = None
    try:
        server.bind(SOCKET_PATH)
        server.listen(1)

        if args.resume:
            data = torch.load(args.from_state, map_location=device)
            ginka.load_state_dict(data["model_state"], strict=False)
            print("Train from loaded state.")

        print("Waiting for client connection...")
        conn, _ = server.accept()
        print("Client connected.")

        for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
            # Reported in the loss log; stays NaN if no validation pass ran.
            ginka_val_loss = float("nan")
            minamo_val_loss = float("nan")

            # -------------------- Train the generator
            for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model", leave=False):
                ginka.train()
                minamo.eval()
                total_loss = 0

                for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
                    # Move the batch to the device
                    target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
                    # Forward pass
                    optimizer_ginka.zero_grad()
                    _, output_softmax = ginka(feat_vec)
                    # Loss
                    losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
                    # Backward pass
                    losses.backward()
                    optimizer_ginka.step()
                    total_loss += losses.item()

                avg_loss = total_loss / len(ginka_dataloader)
                tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")

                # Learning-rate schedule
                scheduler_ginka.step(epoch + 1)

                if (epoch + 1) % 5 == 0:
                    loss_val = 0
                    ginka.eval()
                    idx = 0
                    with torch.no_grad():
                        for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"):
                            target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
                            output, output_softmax = ginka(feat_vec)
                            losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
                            loss_val += losses.item()
                            if epoch + 1 == EPOCHS_GINKA:
                                # Render images during the last validation pass
                                map_matrix = torch.argmax(output, dim=1).cpu().numpy()
                                for matrix in map_matrix:
                                    image = matrix_to_image_cv(matrix, tile_dict)
                                    cv2.imwrite(f"result/ginka_img/{idx}.png", image)
                                    idx += 1

                    ginka_val_loss = loss_val / len(ginka_dataloader_val)
                    tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {ginka_val_loss:.6f}")
                    torch.save({
                        "model_state": ginka.state_dict()
                    }, f"result/ginka_checkpoint/{epoch + 1}.pth")

            # Generate the discriminator's training data from the training set.
            # Batches are collected and concatenated once; the empty seed
            # arrays keep shape and dtype defined even for an empty loader.
            gen_chunks = [np.empty((0, 13, 13), np.int8)]
            prob_chunks = [np.empty((0, 32, 13, 13), np.float32)]
            with torch.no_grad():
                for batch in ginka_dataloader:
                    target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
                    output, output_softmax = ginka(feat_vec)
                    prob_chunks.append(output_softmax.cpu().numpy())
                    gen_chunks.append(torch.argmax(output, dim=1).cpu().numpy())
            gen_list: np.ndarray = np.concatenate(gen_chunks, axis=0)
            prob_list: np.ndarray = np.concatenate(prob_chunks, axis=0)

            tqdm.write(f"Cycle {cycle} Ginka train ended.")
            torch.save({
                "model_state": ginka.state_dict()
            }, f"result/gan/ginka-{cycle}.pth")
            torch.save({
                "model_state": ginka.state_dict()
            }, f"result/ginka.pth")

            # -------------------- Build the Minamo training data

            # Python -> node protocol (bytes):
            # 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
            N, H, W = gen_list.shape
            buf = bytearray(struct.pack('>hbb', N, H, W))   # count, height, width
            buf.extend(gen_list.astype(np.int8).tobytes())  # map tensor
            conn.sendall(buf)
            data = parse_minamo_data(conn, prob_list)

            vis_sim = sum(item[2] for item in data) / len(data)
            topo_sim = sum(item[3] for item in data) / len(data)

            with open(LOSS_PATH, 'a', encoding='utf-8') as f:
                f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {ginka_val_loss:.12f}')

            # Experience replay
            _exchange_replay(conn, data, prob_list, H, W)

            minamo_dataset.set_data(data)

            # -------------------- Train the discriminator
            for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
                ginka.eval()
                minamo.train()
                total_loss = 0

                for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
                    map1, map2, vis_sim, topo_sim, graph1, graph2 = parse_minamo_batch(batch)
                    batch_size = map1.shape[0]

                    if batch_size == 1:
                        continue

                    # Forward pass
                    optimizer_minamo.zero_grad()
                    vis_feat_real, topo_feat_real = minamo(map1, graph1)
                    vis_feat_ref, topo_feat_ref = minamo(map2, graph2)

                    # Fake samples
                    with torch.no_grad():
                        fake_feat = torch.randn((batch_size, 1024), device=device)
                        # Every other call in this function unpacks
                        # (output, output_softmax) from ginka; the softmax map
                        # is the tensor that can be interpolated with map2.
                        # NOTE(review): confirm the softmax output is the intended fake sample.
                        _, fake_data = ginka(fake_feat)

                    # Interpolated samples
                    alpha = torch.rand((batch_size, 1, 1, 1), device=device)
                    interpolates = (alpha * map2 + (1 - alpha) * fake_data).requires_grad_(True)

                    # NOTE(review): minamo is called with (map, graph) above but
                    # with a single argument here - verify against MinamoModel.forward.
                    vis_feat_fake, topo_feat_fake = minamo(fake_data)
                    vis_feat_interp, topo_feat_interp = minamo(interpolates)

                    vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
                    topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)
                    vis_pred_fake = F.cosine_similarity(vis_feat_fake, vis_feat_ref, dim=1).unsqueeze(-1)
                    topo_pred_fake = F.cosine_similarity(topo_feat_fake, topo_feat_ref, dim=1).unsqueeze(-1)
                    vis_pred_interp = F.cosine_similarity(vis_feat_interp, vis_feat_ref, dim=1).unsqueeze(-1)
                    topo_pred_interp = F.cosine_similarity(topo_feat_interp, topo_feat_ref, dim=1).unsqueeze(-1)

                    # Scores
                    score_real = F.l1_loss(vis_pred_real, vis_sim) * VISION_ALPHA + F.l1_loss(topo_pred_real, topo_sim) * (1 - VISION_ALPHA)
                    score_fake = vis_pred_fake * VISION_ALPHA + topo_pred_fake * (1 - VISION_ALPHA)
                    score_interp = vis_pred_interp * VISION_ALPHA + topo_pred_interp * (1 - VISION_ALPHA)

                    # Loss
                    loss = criterion.discriminator_loss(score_real, score_fake, score_interp)

                    # Backward pass
                    loss.backward()
                    optimizer_minamo.step()
                    total_loss += loss.item()

                ave_loss = total_loss / len(minamo_dataloader)
                tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}")

                scheduler_minamo.step(epoch + 1)

                # Validate once, after the last discriminator epoch of the cycle
                if epoch + 1 == EPOCHS_MINAMO:
                    minamo.eval()
                    val_loss = 0
                    with torch.no_grad():
                        for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
                            map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)

                            vis_feat_real, topo_feat_real = minamo(map1_val, graph1)
                            vis_feat_ref, topo_feat_ref = minamo(map2_val, graph2)

                            vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
                            topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)

                            # Loss
                            loss_val = criterion_minamo(vis_pred_real, topo_pred_real, vision_simi_val, topo_simi_val)
                            val_loss += loss_val.item()

                    minamo_val_loss = val_loss / len(minamo_dataloader_val)
                    tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {minamo_val_loss:.6f}")
                    torch.save({
                        "model_state": minamo.state_dict()
                    }, f"result/minamo_checkpoint/{epoch + 1}.pth")

            tqdm.write(f"Cycle {cycle} Minamo train ended.")
            torch.save({
                "model_state": minamo.state_dict()
            }, f"result/gan/minamo-{cycle}.pth")
            torch.save({
                "model_state": minamo.state_dict()
            }, f"result/minamo.pth")
            with open(LOSS_PATH, 'a', encoding='utf-8') as f:
                f.write(f' | Minamo: {minamo_val_loss:.12f}\n')
    finally:
        # Release both sockets even when training is interrupted.
        if conn is not None:
            conn.close()
        server.close()

    print("Train ended.")
|
|
||||||
|
|
||||||
# Script entry point: limit PyTorch to 4 intra-op CPU threads, then start training.
if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
|
|
||||||
@ -16,7 +16,7 @@ from shared.graph import batch_convert_soft_map_to_graph
|
|||||||
from shared.image import matrix_to_image_cv
|
from shared.image import matrix_to_image_cv
|
||||||
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
|
||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 16
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
@ -30,39 +30,56 @@ def parse_arguments():
|
|||||||
parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
|
parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
|
||||||
parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
|
parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
|
||||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||||
|
parser.add_argument("--validate", type=str, default="ginka-eval.json")
|
||||||
parser.add_argument("--epochs", type=int, default=100)
|
parser.add_argument("--epochs", type=int, default=100)
|
||||||
parser.add_argument("--checkpoint", type=int, default=5)
|
parser.add_argument("--checkpoint", type=int, default=5)
|
||||||
parser.add_argument("--load_optim", type=bool, default=True)
|
parser.add_argument("--load_optim", type=bool, default=True)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def clip_weights(model, clip_value=0.01):
    """Clamp every parameter of ``model`` into ``[-clip_value, clip_value]``."""
    for weight in model.parameters():
        weight.data = weight.data.clamp(-clip_value, clip_value)


def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the generator once per curriculum stage on its own masked input.

    ``gen`` is called as ``gen(masked_k, k)`` for k = 1, 2, 3. When ``detach``
    is true the three outputs are detached from the autograd graph.
    """
    fakes = tuple(
        gen(masked, stage)
        for stage, masked in enumerate((masked1, masked2, masked3), start=1)
    )
    if detach:
        fakes = tuple(fake.detach() for fake in fakes)
    return fakes


def gen_total(gen, input, detach=False) -> torch.Tensor:
    """Chain the three generator stages, feeding each output into the next.

    Returns the stage-3 output, detached from the autograd graph when
    ``detach`` is true.
    """
    current = input
    for stage in (1, 2, 3):
        current = gen(current, stage)
    return current.detach() if detach else current
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
|
||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
# c_steps = 1 if args.resume else 5
|
|
||||||
# g_steps = 5 if args.resume else 1
|
|
||||||
c_steps = 5
|
c_steps = 5
|
||||||
g_steps = 1
|
g_steps = 1
|
||||||
|
# 1 代表课程学习阶段,2 代表课程学习后,逐渐转为联合学习的阶段
|
||||||
|
# 3 代表课程学习后的联合遮挡学习阶段,4 代表最后随机输入的联合学习阶段
|
||||||
|
train_stage = 1
|
||||||
|
mask_ratio = 0.1 # 蒙版区域大小,每次增加 0.1,到达 0.9 之后进入阶段 2 的训练
|
||||||
|
random_ratio = 0
|
||||||
|
|
||||||
ginka = GinkaModel()
|
ginka = GinkaModel()
|
||||||
minamo = MinamoScoreModule()
|
minamo = MinamoScoreModule()
|
||||||
minamo_sim = MinamoSimilarityModel()
|
|
||||||
ginka.to(device)
|
ginka.to(device)
|
||||||
minamo.to(device)
|
minamo.to(device)
|
||||||
minamo_sim.to(device)
|
|
||||||
|
|
||||||
dataset = GinkaWGANDataset(args.train, device)
|
dataset = GinkaWGANDataset(args.train, device)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
|
dataset_val = GinkaWGANDataset(args.validate, device)
|
||||||
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
|
|
||||||
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
|
||||||
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
|
||||||
optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)
|
|
||||||
|
|
||||||
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
# scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
|
||||||
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
|
# scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
|
||||||
@ -82,87 +99,134 @@ def train():
|
|||||||
ginka.load_state_dict(data_ginka["model_state"], strict=False)
|
ginka.load_state_dict(data_ginka["model_state"], strict=False)
|
||||||
minamo.load_state_dict(data_minamo["model_state"], strict=False)
|
minamo.load_state_dict(data_minamo["model_state"], strict=False)
|
||||||
|
|
||||||
if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None:
|
if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
|
||||||
c_steps = data_ginka["c_steps"]
|
c_steps = data_ginka["c_steps"]
|
||||||
g_steps = data_ginka["g_steps"]
|
g_steps = data_ginka["g_steps"]
|
||||||
|
|
||||||
|
if data_ginka.get("mask_ratio") is not None:
|
||||||
|
mask_ratio = data_ginka["mask_ratio"]
|
||||||
|
|
||||||
|
if data_ginka.get("random_ratio") is not None:
|
||||||
|
random_ratio = data_ginka["random_ratio"]
|
||||||
|
|
||||||
|
if data_ginka.get("stage") is not None:
|
||||||
|
train_stage = data_ginka["stage"]
|
||||||
|
|
||||||
if args.load_optim:
|
if args.load_optim:
|
||||||
if data_ginka["optim_state"] is not None:
|
if data_ginka.get("optim_state") is not None:
|
||||||
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
|
optimizer_ginka.load_state_dict(data_ginka["optim_state"])
|
||||||
if data_minamo["optim_state"] is not None:
|
if data_minamo.get("optim_state") is not None:
|
||||||
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
|
optimizer_minamo.load_state_dict(data_minamo["optim_state"])
|
||||||
if data_minamo["optim_state_sim"] is not None:
|
|
||||||
optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"])
|
dataset.train_stage = train_stage
|
||||||
|
dataset.mask_ratio1 = mask_ratio
|
||||||
|
dataset.mask_ratio2 = mask_ratio
|
||||||
|
dataset.mask_ratio3 = mask_ratio
|
||||||
|
dataset.random_ratio = random_ratio
|
||||||
|
|
||||||
|
dataset_val.train_stage = train_stage
|
||||||
|
dataset_val.mask_ratio1 = mask_ratio
|
||||||
|
dataset_val.mask_ratio2 = mask_ratio
|
||||||
|
dataset_val.mask_ratio3 = mask_ratio
|
||||||
|
dataset_val.random_ratio = random_ratio
|
||||||
|
|
||||||
print("Train from loaded state.")
|
print("Train from loaded state.")
|
||||||
|
|
||||||
|
low_loss_epochs = 0
|
||||||
|
|
||||||
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
|
||||||
loss_total_minamo = torch.Tensor([0]).to(device)
|
loss_total_minamo = torch.Tensor([0]).to(device)
|
||||||
loss_total_minamo_sim = torch.Tensor([0]).to(device)
|
|
||||||
loss_total_ginka = torch.Tensor([0]).to(device)
|
loss_total_ginka = torch.Tensor([0]).to(device)
|
||||||
dis_total = torch.Tensor([0]).to(device)
|
dis_total = torch.Tensor([0]).to(device)
|
||||||
|
loss_ce_total = torch.Tensor([0]).to(device)
|
||||||
|
|
||||||
for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||||
batch_size = real_data.size(0)
|
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
|
||||||
real_data = real_data.to(device)
|
|
||||||
real_graph = batch_convert_soft_map_to_graph(real_data)
|
|
||||||
|
|
||||||
# ---------- 训练判别器
|
# ---------- 训练判别器
|
||||||
for _ in range(c_steps):
|
for _ in range(c_steps):
|
||||||
# 生成假样本
|
# 生成假样本
|
||||||
optimizer_minamo.zero_grad()
|
optimizer_minamo.zero_grad()
|
||||||
z = torch.rand(batch_size, 1024, device=device)
|
optimizer_ginka.zero_grad()
|
||||||
fake_data = ginka(z)
|
if train_stage == 1 or train_stage == 2:
|
||||||
fake_data = fake_data.detach()
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||||
|
|
||||||
|
loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
|
||||||
|
loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
|
||||||
|
loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)
|
||||||
|
|
||||||
|
dis_avg = (dis1 + dis2 + dis3) / 3.0
|
||||||
|
loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0
|
||||||
|
|
||||||
# 计算判别器输出
|
|
||||||
# 反向传播
|
# 反向传播
|
||||||
dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
|
loss_d_avg.backward()
|
||||||
loss_d.backward()
|
elif train_stage == 3:
|
||||||
# torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0)
|
pass
|
||||||
# total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
|
|
||||||
# print("Critic 梯度范数:", total_norm.item())
|
|
||||||
# print("Critic 输入范围:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
|
|
||||||
# print("Critic 输出范围:", d_real.min().item(), d_real.max().item())
|
|
||||||
optimizer_minamo.step()
|
optimizer_minamo.step()
|
||||||
|
|
||||||
loss_total_minamo += loss_d.detach()
|
loss_total_minamo += loss_d_avg.detach()
|
||||||
dis_total += dis.detach()
|
dis_total += dis_avg.detach()
|
||||||
|
|
||||||
# ---------- 训练生成器
|
# ---------- 训练生成器
|
||||||
|
|
||||||
for _ in range(g_steps):
|
for _ in range(g_steps):
|
||||||
|
optimizer_minamo.zero_grad()
|
||||||
optimizer_ginka.zero_grad()
|
optimizer_ginka.zero_grad()
|
||||||
# optimizer_minamo_sim.zero_grad()
|
if train_stage == 1 or train_stage == 2:
|
||||||
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)
|
||||||
|
|
||||||
z1 = torch.randn(batch_size, 1024, device=device)
|
loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1)
|
||||||
z2 = torch.randn(batch_size, 1024, device=device)
|
loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2)
|
||||||
fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)
|
loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3)
|
||||||
|
|
||||||
# 先训练辅助判别器
|
loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
|
||||||
# loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2)
|
loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)
|
||||||
# loss_c_assist.backward(retain_graph=True)
|
|
||||||
# optimizer_minamo_sim.step()
|
|
||||||
|
|
||||||
loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
|
|
||||||
loss_g.backward()
|
loss_g.backward()
|
||||||
optimizer_ginka.step()
|
optimizer_ginka.step()
|
||||||
|
loss_total_ginka += loss_g.detach()
|
||||||
|
loss_ce_total += loss_ce.detach()
|
||||||
|
|
||||||
loss_total_ginka += loss_g
|
elif train_stage == 3:
|
||||||
# loss_total_minamo_sim += loss_c_assist.detach()
|
pass
|
||||||
# tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")
|
|
||||||
|
|
||||||
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
|
||||||
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
|
||||||
avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps
|
avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
|
||||||
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
avg_dis = dis_total.item() / len(dataloader) / c_steps
|
||||||
tqdm.write(
|
tqdm.write(
|
||||||
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\
|
f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
|
||||||
f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\
|
f"Epoch: {epoch + 1} | W: {avg_dis:.8f} | " +
|
||||||
f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\
|
f"G: {avg_loss_ginka:.8f} | D: {avg_loss_minamo:.8f} | " +
|
||||||
f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
|
f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if avg_loss_ce < 0.5:
|
||||||
|
low_loss_epochs += 1
|
||||||
|
else:
|
||||||
|
low_loss_epochs = 0
|
||||||
|
|
||||||
|
if low_loss_epochs >= 5 and train_stage == 2:
|
||||||
|
random_ratio += 0.1
|
||||||
|
random_ratio = min(random_ratio, 0.5)
|
||||||
|
low_loss_epochs = 0
|
||||||
|
|
||||||
|
if low_loss_epochs >= 5 and train_stage == 1:
|
||||||
|
if mask_ratio >= 0.9:
|
||||||
|
train_stage = 2
|
||||||
|
|
||||||
|
mask_ratio += 0.1
|
||||||
|
mask_ratio = min(mask_ratio, 0.9)
|
||||||
|
low_loss_epochs = 0
|
||||||
|
|
||||||
|
dataset.train_stage = 2
|
||||||
|
dataset_val.train_stage = 2
|
||||||
|
dataset.random_ratio = random_ratio
|
||||||
|
dataset_val.random_ratio = random_ratio
|
||||||
|
dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
|
||||||
|
dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
|
||||||
|
|
||||||
# scheduler_ginka.step()
|
# scheduler_ginka.step()
|
||||||
# scheduler_minamo.step()
|
# scheduler_minamo.step()
|
||||||
|
|
||||||
@ -172,46 +236,51 @@ def train():
|
|||||||
g_steps = 1
|
g_steps = 1
|
||||||
|
|
||||||
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
|
if avg_loss_ginka > 0 or avg_loss_minamo > 0:
|
||||||
c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15)
|
c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1))
|
||||||
else:
|
else:
|
||||||
c_steps = 5
|
c_steps = 5
|
||||||
|
|
||||||
# 每若干轮输出一次图片,并保存检查点
|
# 每若干轮输出一次图片,并保存检查点
|
||||||
if (epoch + 1) % args.checkpoint == 0:
|
if (epoch + 1) % args.checkpoint == 0:
|
||||||
# 输出 20 张图片,每批次 4 张,一共五批
|
|
||||||
idx = 0
|
|
||||||
with torch.no_grad():
|
|
||||||
for _ in range(5):
|
|
||||||
z = torch.randn(4, 1024, device=device)
|
|
||||||
output = ginka(z)
|
|
||||||
|
|
||||||
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
|
||||||
for matrix in map_matrix:
|
|
||||||
image = matrix_to_image_cv(matrix, tile_dict)
|
|
||||||
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
|
||||||
idx += 1
|
|
||||||
|
|
||||||
# 保存检查点
|
# 保存检查点
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": ginka.state_dict(),
|
"model_state": ginka.state_dict(),
|
||||||
"optim_state": optimizer_ginka.state_dict(),
|
"optim_state": optimizer_ginka.state_dict(),
|
||||||
"c_steps": c_steps,
|
"c_steps": c_steps,
|
||||||
"g_steps": g_steps
|
"g_steps": g_steps,
|
||||||
|
"stage": train_stage,
|
||||||
|
"mask_ratio": mask_ratio,
|
||||||
|
"random_ratio": random_ratio,
|
||||||
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
}, f"result/wgan/ginka-{epoch + 1}.pth")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": minamo.state_dict(),
|
"model_state": minamo.state_dict(),
|
||||||
"model_state_sim": minamo_sim.state_dict(),
|
"optim_state": optimizer_minamo.state_dict()
|
||||||
"optim_state": optimizer_minamo.state_dict(),
|
|
||||||
"optim_state_sim": optimizer_minamo_sim.state_dict()
|
|
||||||
}, f"result/wgan/minamo-{epoch + 1}.pth")
|
}, f"result/wgan/minamo-{epoch + 1}.pth")
|
||||||
|
|
||||||
|
idx = 0
|
||||||
|
with torch.no_grad():
|
||||||
|
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||||
|
real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
|
||||||
|
if train_stage == 1:
|
||||||
|
fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
|
||||||
|
fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
|
||||||
|
fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
|
||||||
|
fake3 = torch.argmax(fake3, dim=1).cpu().numpy()
|
||||||
|
|
||||||
|
for i in range(fake1.shape[0]):
|
||||||
|
for key, one in enumerate([fake1, fake2, fake3]):
|
||||||
|
map_matrix = one[i]
|
||||||
|
image = matrix_to_image_cv(map_matrix, tile_dict)
|
||||||
|
cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)
|
||||||
|
|
||||||
|
idx += 1
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": ginka.state_dict(),
|
"model_state": ginka.state_dict(),
|
||||||
}, f"result/ginka.pth")
|
}, f"result/ginka.pth")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": minamo.state_dict(),
|
"model_state": minamo.state_dict(),
|
||||||
"model_state_sim": minamo_sim.state_dict(),
|
|
||||||
}, f"result/minamo.pth")
|
}, f"result/minamo.pth")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@ -20,23 +20,41 @@ class MinamoModel(nn.Module):
|
|||||||
|
|
||||||
return vision_feat, topo_feat
|
return vision_feat, topo_feat
|
||||||
|
|
||||||
|
class MinamoScoreHead(nn.Module):
    """Pair of spectral-normalised linear heads, one per feature branch.

    Each branch (vision / topology) gets its own ``Linear(in_dim, out_dim)``
    wrapped in ``spectral_norm``.
    """

    def __init__(self, in_dim: int, out_dim: int):
        """Build the two heads.

        Args:
            in_dim: Size of the last dimension of each incoming feature tensor.
            out_dim: Size of the last dimension of each produced score tensor.
        """
        super().__init__()
        # The single-layer nn.Sequential wrappers are kept on purpose: they
        # put the linear layers at the submodule paths ``vision_fc.0`` and
        # ``topo_fc.0``, which is what the parameter names are derived from.
        self.vision_fc = nn.Sequential(
            spectral_norm(nn.Linear(in_dim, out_dim)),
        )
        self.topo_fc = nn.Sequential(
            spectral_norm(nn.Linear(in_dim, out_dim)),
        )

    def forward(self, vis_feat, topo_feat):
        """Score both feature tensors independently.

        Args:
            vis_feat: Vision features; last dimension must equal ``in_dim``.
            topo_feat: Topology features; last dimension must equal ``in_dim``.

        Returns:
            Tuple ``(vis_score, topo_score)``.
        """
        vis_score = self.vision_fc(vis_feat)
        topo_score = self.topo_fc(topo_feat)
        return vis_score, topo_score
|
||||||
|
|
||||||
class MinamoScoreModule(nn.Module):
    """Critic that scores a map with one of three stage-specific heads.

    A vision backbone and a topology backbone each produce a feature tensor;
    the ``stage`` argument of ``forward`` picks which ``MinamoScoreHead``
    turns those features into scores.
    """

    def __init__(self, tile_types=32):
        """Build both backbones and the three score heads.

        Args:
            tile_types: Forwarded unchanged to ``MinamoTopoModel`` and
                ``MinamoVisionModel``.
        """
        super().__init__()
        self.topo_model = MinamoTopoModel(tile_types)
        self.vision_model = MinamoVisionModel(tile_types)
        # Output layer: one head per training stage.
        # NOTE(review): 512 is presumably the feature width emitted by both
        # backbones -- confirm against MinamoTopoModel / MinamoVisionModel.
        self.head1 = MinamoScoreHead(512, 1)
        self.head2 = MinamoScoreHead(512, 1)
        self.head3 = MinamoScoreHead(512, 1)

    def forward(self, map, graph, stage):
        """Score ``map`` / ``graph`` using the head for ``stage``.

        Args:
            map: Input passed to the vision backbone.
            graph: Input passed to the topology backbone.
            stage: Selects the head; must compare equal to 1, 2 or 3.

        Returns:
            Tuple ``(score, vision_score, topo_score)`` where ``score`` is
            ``VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score``.

        Raises:
            RuntimeError: If ``stage`` is not 1, 2 or 3.
        """
        vision_feat = self.vision_model(map)
        topo_feat = self.topo_model(graph)

        if stage == 1:
            vision_score, topo_score = self.head1(vision_feat, topo_feat)
        elif stage == 2:
            vision_score, topo_score = self.head2(vision_feat, topo_feat)
        elif stage == 3:
            vision_score, topo_score = self.head3(vision_feat, topo_feat)
        else:
            raise RuntimeError("Unknown critic stage.")

        score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
        return score, vision_score, topo_score
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user