Mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-16 22:41:14 +08:00)
refactor: curriculum training
This commit is contained in:
parent
99f46150be
commit
f6b1ad6ebd
137 ginka/dataset.py
@@ -3,10 +3,18 @@ import random
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from minamo.model.model import MinamoModel
from shared.graph import differentiable_convert_to_data
import torch
import torch.nn.functional as F
from typing import List
from shared.utils import random_smooth_onehot

STAGE1_MASK = [0, 1, 10, 11]
STAGE1_REMOVE = [2, 3, 4, 5, 6, 7, 8, 9, 12]
STAGE2_MASK = [6, 7, 8, 9]
STAGE2_REMOVE = [2, 3, 4, 5, 12]
STAGE3_MASK = [2, 3, 4, 5, 12]
STAGE3_REMOVE = []

def load_data(path: str):
    with open(path, 'r', encoding="utf-8") as f:
        data = json.load(f)
@@ -23,38 +31,45 @@ def load_minamo_gan_data(data: list):
        res.append((one['map1'], one['map2'], one['visionSimilarity'], one['topoSimilarity'], True))
    return res

class GinkaDataset(Dataset):
    def __init__(self, data_path: str, device, minamo: MinamoModel):
        self.data = load_data(data_path)  # custom data-loading function
        self.max_size = 32
        self.minamo = minamo
        self.device = device
def apply_curriculum_mask(
    maps: torch.Tensor,         # [C, H, W]
    mask_classes: List[int],    # class indices to mask out
    remove_classes: List[int],  # class indices to remove
    mask_ratio: float           # masking ratio in [0, 1]
) -> tuple[torch.Tensor, torch.Tensor]:
    C, H, W = maps.shape
    device = maps.device
    masked_maps = maps.clone()

    def __len__(self):
        return len(self.data)
    # Step 1: remove the unwanted classes (set them all to class 0)
    if remove_classes:
        remove_mask = masked_maps[remove_classes, :, :].sum(dim=0, keepdim=True) > 0
        masked_maps[:, :, :][remove_mask.expand(C, -1, -1)] = 0
        masked_maps[0][remove_mask[0, :, :]] = 1  # set to "empty floor"

    removed_maps = masked_maps.clone()

    def __getitem__(self, idx):
        item = self.data[idx]

        target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        min_main = random.uniform(0.75, 0.9)
        max_main = random.uniform(0.9, 1)
        epsilon = random.uniform(0, 0.25)
        target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon)
        graph = differentiable_convert_to_data(target_smooth).to(self.device)
        target = target.to(self.device)
        vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)

        return {
            "target_vision_feat": vision_feat,
            "target_topo_feat": topo_feat,
            "target": target,
        }
    # Step 2: randomly mask the specified classes
    for cls in mask_classes:
        cls_mask = masked_maps[cls] > 0  # boolean pixel mask for the target class [H, W]
        indices = cls_mask.nonzero(as_tuple=False)  # coordinates of every pixel of this class
        num_mask = int(len(indices) * mask_ratio)
        if num_mask > 0:
            selected = indices[torch.randperm(len(indices))[:num_mask]]
            masked_maps[cls, selected[:, 0], selected[:, 1]] = 0
            masked_maps[0, selected[:, 0], selected[:, 1]] = 1  # set to "empty floor"

    return removed_maps, masked_maps
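A quick smoke test of apply_curriculum_mask, assuming a 32-class one-hot map as above (hedged sketch, not part of the commit):

import torch
import torch.nn.functional as F

tiles = torch.randint(0, 13, (13, 13))                               # toy tile-id map
onehot = F.one_hot(tiles, num_classes=32).permute(2, 0, 1).float()   # [32, 13, 13]
removed, masked = apply_curriculum_mask(onehot, STAGE1_MASK, STAGE1_REMOVE, 0.5)
# `removed` only erases the REMOVE classes; `masked` additionally hides half of
# the MASK-class cells, so (removed, masked) forms an (answer, question) pair
# for stage-1 inpainting.
assert removed.shape == masked.shape == (32, 13, 13)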

class GinkaWGANDataset(Dataset):
    def __init__(self, data_path: str, device):
        self.data = load_data(data_path)  # custom data-loading function
        self.device = device
        self.train_stage = 1
        self.mask_ratio1 = 0.1
        self.mask_ratio2 = 0.1
        self.mask_ratio3 = 0.1
        self.random_ratio = 0.0

    def __len__(self):
        return len(self.data)
@@ -63,56 +78,20 @@ class GinkaWGANDataset(Dataset):
        item = self.data[idx]

        target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        # min_main = random.uniform(0.8, 0.9)
        # max_main = random.uniform(0.9, 1)
        # epsilon = random.uniform(0, 0.2)
        # target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)

        if self.train_stage == 1:
            removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, self.mask_ratio1)
            removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, self.mask_ratio2)
            removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, self.mask_ratio3)
        elif self.train_stage == 2:
            removed1, masked1 = apply_curriculum_mask(target, STAGE1_MASK, STAGE1_REMOVE, random.uniform(0.1, 0.9))
            removed2, masked2 = apply_curriculum_mask(target, STAGE2_MASK, STAGE2_REMOVE, random.uniform(0.1, 0.9))
            removed3, masked3 = apply_curriculum_mask(target, STAGE3_MASK, STAGE3_REMOVE, random.uniform(0.1, 0.9))

        return target

class MinamoGANDataset(Dataset):
    def __init__(self, refer_data_path):
        self.refer = load_minamo_gan_data(load_data(refer_data_path))
        self.data = list()
        self.data.extend(random.sample(self.refer, 1000))

    def set_data(self, data: list):
        self.data.clear()
        self.data.extend(data)
        k = min(len(data) / 4, len(self.refer))
        self.data.extend(random.sample(self.refer, int(k)))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # assume map2 is the reference map
        item = self.data[idx]

        map1, map2, vis_sim, topo_sim, review = item
        # Check the review flag; if it is absent, the map is already a
        # probability distribution and needs no conversion.
        if review:
            map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        else:
            map1 = torch.FloatTensor(map1)
        map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]

        min_main = random.uniform(0.75, 0.9)
        max_main = random.uniform(0.9, 1)
        epsilon = random.uniform(0, 0.25)

        if review:
            map1 = random_smooth_onehot(map1, min_main, max_main, epsilon)
            map2 = random_smooth_onehot(map2, min_main, max_main, epsilon)

        graph1 = differentiable_convert_to_data(map1)
        graph2 = differentiable_convert_to_data(map2)

        return (
            map1,
            map2,
            torch.FloatTensor([vis_sim]),
            torch.FloatTensor([topo_sim]),
            graph1,
            graph2
        )
        if self.random_ratio > 0:
            removed1 = random_smooth_onehot(removed1, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
            removed2 = random_smooth_onehot(removed2, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)
            removed3 = random_smooth_onehot(removed3, min_main=1 - self.random_ratio, max_main=1.0, epsilon=self.random_ratio)

        return removed1, masked1, removed2, masked2, removed3, masked3
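For orientation, a minimal sketch of consuming a GinkaWGANDataset batch (assumed usage, not part of the commit):

from torch.utils.data import DataLoader

dataset = GinkaWGANDataset("ginka-dataset.json", "cpu")
loader = DataLoader(dataset, batch_size=16, shuffle=True)
for batch in loader:
    # six tensors per sample: one (answer, question) pair per curriculum stage
    real1, masked1, real2, masked2, real3, masked3 = batch
    break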
@@ -2,43 +2,24 @@ import torch
import torch.nn as nn

class GinkaInput(nn.Module):
    def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)):
    def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
        super().__init__()
        self.out_size = out_size
        self.fc = nn.Sequential(
            nn.Linear(feat_dim, size[0] * size[1] * out_ch),
            nn.Unflatten(1, (out_ch, *size))
            nn.Linear(in_size[0] * in_size[1], out_size[0] * out_size[1]),
            nn.LayerNorm(out_size[0] * out_size[1]),
            nn.ELU()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(out_ch),
            nn.ELU()
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.view(B, C, H * W)
        x = self.fc(x)
        x = x.view(B, C, self.out_size[0], self.out_size[1])
        x = self.conv(x)
        return x
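A shape walk-through of the rewritten GinkaInput, assuming a [B, 32, 13, 13] one-hot map input (sketch only):

import torch

layer = GinkaInput(in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32))
x = torch.randn(2, 32, 13, 13)
# view -> [2, 32, 169]; Linear(169 -> 1024) applied per channel -> [2, 32, 1024];
# view -> [2, 32, 32, 32]; the 3x3 conv then lifts channels -> [2, 64, 32, 32]
print(layer(x).shape)  # torch.Size([2, 64, 32, 32])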

class FeatureEncoder(nn.Module):
    def __init__(self, feat_dim, size, mid_ch, out_ch):
        super().__init__()
        self.encode = nn.Sequential(
            nn.Linear(feat_dim, mid_ch * size * size),
            nn.Unflatten(1, (mid_ch, size, size)),
            nn.Conv2d(mid_ch, out_ch, 1)
        )

    def forward(self, x):
        x = self.encode(x)
        return x

class GinkaFeatureInput(nn.Module):
    def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64):
        super().__init__()
        self.encode1 = FeatureEncoder(feat_dim, 32, mid_ch, out_ch)
        self.encode2 = FeatureEncoder(feat_dim, 16, mid_ch * 2, out_ch * 2)
        self.encode3 = FeatureEncoder(feat_dim, 8, mid_ch * 4, out_ch * 4)
        self.encode4 = FeatureEncoder(feat_dim, 4, mid_ch * 8, out_ch * 8)
        self.encode5 = FeatureEncoder(feat_dim, 2, mid_ch * 16, out_ch * 16)

    def forward(self, x):
        x1 = self.encode1(x)
        x2 = self.encode2(x)
        x3 = self.encode3(x)
        x4 = self.encode4(x)
        x5 = self.encode5(x)
        return x1, x2, x3, x4, x5

@@ -13,6 +13,13 @@ from shared.similarity.vision import calculate_visual_similarity
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 12

STAGE_ALLOWED = [
    [],
    [0, 1, 10, 11],
    [6, 7, 8, 9],
    [2, 3, 4, 5, 12]
]

def get_not_allowed(classes: list[int], include_illegal=False):
    res = list()
    for num in range(0, CLASS_NUM):
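The hunk is truncated here; a plausible completion of get_not_allowed, assuming include_illegal gates whether tile ids above ILLEGAL_MAX_NUM are also returned:

        # Assumed completion (the diff context cuts off mid-function).
        if num in classes:
            continue
        if num > ILLEGAL_MAX_NUM and not include_illegal:
            continue
        res.append(num)
    return res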
@@ -301,24 +308,47 @@ def js_divergence(p, q, eps=1e-8):

    return torch.clamp(0.5 * (kl_pm + kl_qm), max=1.0)

def immutable_penalty_loss(
    pred: torch.Tensor, input: torch.Tensor, modifiable_classes: list[int]
) -> torch.Tensor:
    """
    Loss that penalizes the model for modifying immutable regions.

    Args:
        pred: model output [B, C, H, W], a probability distribution (after softmax)
        input: original input map [B, C, H, W], a probability distribution (after softmax)
        modifiable_classes: classes that are allowed to be modified
    """
    not_allowed = get_not_allowed(modifiable_classes, include_illegal=True)
    input_mask = pred[:, not_allowed, :, :]
    with torch.no_grad():
        target_mask = torch.argmax(input[:, not_allowed, :, :], dim=1)
        target_mask = F.one_hot(target_mask, num_classes=len(not_allowed)).permute(0, 3, 1, 2).float()

    # regions of difference (where the model tries to change things)
    penalty = F.cross_entropy(input_mask, target_mask)

    return penalty

class WGANGinkaLoss:
    def __init__(self, lambda_gp=100, weight=[0.8, 0.1, 0.1], diversity_lamda=0.4):
    def __init__(self, lambda_gp=100, weight=[1, 0.4, 10, 0.2, 0.2]):
        # weight: critic loss, cross-entropy loss, immutable-class loss, constraint loss, entrance loss
        self.lambda_gp = lambda_gp  # gradient penalty coefficient
        self.weight = weight
        self.diversity_lamda = diversity_lamda

    def compute_gradient_penalty(self, critic, real_data, fake_data):
    def compute_gradient_penalty(self, critic, stage, real_data, fake_data):
        # interpolate between real and fake samples
        batch_size = real_data.size(0)
        epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
        interp_data = interpolate_data(real_data, fake_data, epsilon_data)
        interp_graph = batch_convert_soft_map_to_graph(interp_data)
        interp_data = interpolate_data(real_data, fake_data, epsilon_data).to(real_data.device)
        interp_graph = batch_convert_soft_map_to_graph(interp_data).to(real_data.device)

        # enable gradients on the interpolated samples
        interp_data.requires_grad_()
        interp_graph.x.requires_grad_()

        _, d_vis_score, d_topo_score = critic(interp_data, interp_graph)
        _, d_vis_score, d_topo_score = critic(interp_data, interp_graph, stage)

        # compute the gradients
        grad_vis = torch.autograd.grad(
@@ -344,21 +374,21 @@ class WGANGinkaLoss:
        return gp_loss

    def discriminator_loss(
        self, critic, real_data: torch.Tensor,
        real_graph: torch.Tensor, fake_data: torch.Tensor
    ):
        self, critic, stage: int, real_data: torch.Tensor, fake_data: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Critic (discriminator) loss."""
        real_graph = batch_convert_soft_map_to_graph(real_data)
        fake_graph = batch_convert_soft_map_to_graph(fake_data)
        real_scores, _, _ = critic(real_data, real_graph)
        fake_scores, _, _ = critic(fake_data, fake_graph)

        # print("Critic output range", fake_scores.min().item(), fake_scores.max().item(), real_scores.min().item(), real_scores.max().item())
        real_scores, _, _ = critic(real_data, real_graph, stage)
        fake_scores, _, _ = critic(fake_data, fake_graph, stage)

        # Wasserstein distance
        d_loss = fake_scores.mean() - real_scores.mean()
        grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data)
        grad_loss = self.compute_gradient_penalty(critic, stage, real_data, fake_data)

        return d_loss, d_loss + self.lambda_gp * grad_loss
        total_loss = d_loss + self.lambda_gp * grad_loss

        return total_loss, d_loss
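For reference, the critic objective implemented above is the standard WGAN-GP form L_D = E[D(fake)] - E[D(real)] + lambda_gp * E[(||grad D(x_hat)||_2 - 1)^2]. A minimal generic sketch of the penalty term (illustration only; the commit's version additionally interpolates graph features and routes through a stage head):

import torch

def gradient_penalty(critic, real, fake):
    # x_hat: random interpolation between real and fake samples
    eps = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = critic(x_hat)
    grads = torch.autograd.grad(scores.sum(), x_hat, create_graph=True)[0]
    return ((grads.flatten(1).norm(2, dim=1) - 1) ** 2).mean()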

    def calculate_similarity_one(self, map1, map2):
        topo1 = build_topological_graph(map1)
@@ -368,73 +398,29 @@ class WGANGinkaLoss:
        topo_sim = overall_similarity(topo1, topo2)

        return vis_sim, topo_sim

    def discriminator_loss_assist(self, critic, fake_data1, fake_data2):
        graph1 = batch_convert_soft_map_to_graph(fake_data1)
        graph2 = batch_convert_soft_map_to_graph(fake_data2)
        vis_feat_1, topo_feat_1 = critic(fake_data1, graph1)
        vis_feat_2, topo_feat_2 = critic(fake_data2, graph2)

    def generator_loss(self, critic, stage, mask_ratio, real, fake, input) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generator loss."""
        fake_graph = batch_convert_soft_map_to_graph(fake)

        batch1 = torch.argmax(fake_data1, dim=1).cpu().tolist()
        batch2 = torch.argmax(fake_data2, dim=1).cpu().tolist()

        vis_sim_real = []
        topo_sim_real = []

        for i in range(len(batch1)):
            vis_sim, topo_sim = self.calculate_similarity_one(batch1[i], batch2[i])
            vis_sim_real.append(vis_sim)
            topo_sim_real.append(topo_sim)

        vis_sim_real = torch.Tensor(vis_sim_real)
        topo_sim_real = torch.Tensor(topo_sim_real)

        pred_vis_sim = F.cosine_similarity(vis_feat_1, vis_feat_2).cpu()
        pred_topo_sim = F.cosine_similarity(topo_feat_1, topo_feat_2).cpu()

        loss1 = F.l1_loss(pred_vis_sim, vis_sim_real) * VISION_WEIGHT + F.l1_loss(pred_topo_sim, topo_sim_real) * TOPO_WEIGHT

        return loss1

    def discriminator_loss_assist2(self, critic, real_data, fake_data1, fake_data2):
        loss1 = self.discriminator_loss_assist(critic, real_data, fake_data1)
        loss2 = self.discriminator_loss_assist(critic, real_data, fake_data2)
        loss3 = self.discriminator_loss_assist(critic, fake_data1, fake_data2)

        return loss1 / 3.0 + loss2 / 3.0 + loss3 / 3.0

    def generator_loss_one(self, critic, fake, fake_graph):
        fake_scores, _, _ = critic(fake, fake_graph)
        fake_scores, _, _ = critic(fake, fake_graph, stage)
        minamo_loss = -torch.mean(fake_scores)
        class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
        entrance_loss = entrance_constraint_loss(fake)
        ce_loss = F.cross_entropy(fake, real)
        immutable_loss = immutable_penalty_loss(fake, input, STAGE_ALLOWED[stage])
        constraint_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)

        losses = [
            minamo_loss * self.weight[0],
            class_loss * self.weight[1],
            entrance_loss * self.weight[2]
            ce_loss * self.weight[1] / mask_ratio * (1 - mask_ratio),  # the larger the mask, the smaller the cross-entropy weight
            immutable_loss * self.weight[2],
            constraint_loss * self.weight[3]
        ]

        return sum(losses)

    def generator_loss(self, critic, critic_assist, fake1, fake2):
        """Generator loss."""
        fake_graph1 = batch_convert_soft_map_to_graph(fake1)
        fake_graph2 = batch_convert_soft_map_to_graph(fake2)
        if stage == 1:
            # stage 1 additionally checks that entrances exist
            entrance_loss = entrance_constraint_loss(fake)
            losses.append(entrance_loss * self.weight[4])

        loss1 = self.generator_loss_one(critic, fake1, fake_graph1)
        loss2 = self.generator_loss_one(critic, fake2, fake_graph2)
        # print(losses[2].item())

        # vis_feat1, topo_feat1 = critic_assist(fake1, fake_graph1)
        # vis_feat2, topo_feat2 = critic_assist(fake2, fake_graph2)

        # vis_sim = F.cosine_similarity(vis_feat1, vis_feat2)
        # topo_sim = F.cosine_similarity(topo_feat1, topo_feat2)
        # similarity = vis_sim * VISION_WEIGHT + topo_sim * TOPO_WEIGHT

        # print(similarity.mean().item())
        # div_loss = F.l1_loss(fake1[:, :, 1:-1, 1:-1], fake2[:, :, 1:-1, 1:-1])

        return loss1 * 0.5 + loss2 * 0.5\
            # + self.diversity_lamda * F.relu(0.7 - div_loss).mean()
            # + self.diversity_lamda * F.relu(similarity - 0.4).mean()
        return sum(losses), minamo_loss, ce_loss / mask_ratio, immutable_loss

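The cross-entropy term above is scaled by weight[1] * (1 - m) / m for mask ratio m, so reconstruction supervision fades as the curriculum mask grows; a quick check of that factor:

# Curriculum weighting of the CE term as mask_ratio grows (illustration).
for m in [0.1, 0.3, 0.5, 0.7, 0.9]:
    print(f"mask_ratio={m:.1f} -> CE factor {(1 - m) / m:.2f}")
# 0.1 -> 9.00, 0.3 -> 2.33, 0.5 -> 1.00, 0.7 -> 0.43, 0.9 -> 0.11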
@@ -9,27 +9,29 @@ def print_memory(tag=""):
    print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

class GinkaModel(nn.Module):
    def __init__(self, feat_dim=1024, base_ch=64, out_ch=32):
    def __init__(self, base_ch=64, out_ch=32):
        """Ginka model definition."""
        super().__init__()
        self.unet = GinkaUNet(base_ch, base_ch, feat_dim)
        self.input = GinkaInput(32, 32, (13, 13), (32, 32))
        self.unet = GinkaUNet(32, base_ch, base_ch)
        self.output = GinkaOutput(base_ch, out_ch, (13, 13))

    def forward(self, x):
    def forward(self, x, stage):
        """
        Args:
            x: feature vector of the reference map
        Returns:
            logits: output logits [BS, num_classes, H, W]
        """
        x = self.input(x)
        x = self.unet(x)
        x = self.output(x)
        x = self.output(x, stage)
        return F.softmax(x, dim=1)

# check VRAM usage
if __name__ == "__main__":
    feat = torch.randn((1, 1024)).cuda()
    input = torch.randn((1, 32, 13, 13)).cuda()

    # initialize the model
    model = GinkaModel().cuda()
@@ -37,14 +39,13 @@ if __name__ == "__main__":
    print_memory("after init")

    # forward pass
    output = model(feat)
    output = model(input, 1)

    print_memory("after forward")

    print(f"Input shape: feat={feat.shape}")
    print(f"Input shape: feat={input.shape}")
    print(f"Output shape: output={output.shape}")
    # print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
    # print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
    print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
    print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
    print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

@@ -1,13 +1,42 @@
import torch
import torch.nn as nn

class GinkaOutput(nn.Module):
    def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
class StageHead(nn.Module):
    def __init__(self, in_ch, out_ch, out_size=(13, 13)):
        super().__init__()
        self.conv_down = nn.Sequential(
        self.head = nn.Sequential(
            nn.Conv2d(in_ch, in_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(in_ch),
            nn.ELU(),

            nn.Conv2d(in_ch, in_ch, 1),
            nn.InstanceNorm2d(in_ch),
            nn.ELU(),
        )
        self.pool = nn.Sequential(
            nn.AdaptiveMaxPool2d(out_size),
            nn.Conv2d(in_ch, out_ch, 1)
        )

    def forward(self, x):
        return self.conv_down(x)
        x = self.head(x)
        x = self.pool(x)
        return x

class GinkaOutput(nn.Module):
    def __init__(self, in_ch=64, out_ch=32, out_size=(13, 13)):
        super().__init__()
        self.head1 = StageHead(in_ch, out_ch, out_size)
        self.head2 = StageHead(in_ch, out_ch, out_size)
        self.head3 = StageHead(in_ch, out_ch, out_size)

    def forward(self, x, stage):
        if stage == 1:
            x = self.head1(x)
        elif stage == 2:
            x = self.head2(x)
        elif stage == 3:
            x = self.head3(x)
        else:
            raise RuntimeError("Unknown generate stage.")
        return x
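A short sketch of driving the per-stage heads (assumed usage, mirroring the routing above):

import torch

output = GinkaOutput(in_ch=64, out_ch=32, out_size=(13, 13))
feat = torch.randn(2, 64, 32, 32)
for stage in (1, 2, 3):
    logits = output(feat, stage)  # each stage owns independent head weights
    assert logits.shape == (2, 32, 13, 13)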

@@ -198,15 +198,15 @@ class GinkaBottleneck(nn.Module):
        return x

class GinkaUNet(nn.Module):
    def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
    def __init__(self, in_ch=32, base_ch=64, out_ch=32):
        """Ginka model, UNet part."""
        super().__init__()
        self.input = GinkaTransformerEncoder(
            in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32,  # divided by token_size automatically
            token_size=4, ff_dim=feat_dim*2, num_layers=4
        )
        self.down1 = ConvBlock(2, base_ch)
        # self.input = GinkaTransformerEncoder(
        #     in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32,  # divided by token_size automatically
        #     token_size=4, ff_dim=feat_dim*2, num_layers=4
        # )
        self.down1 = ConvBlock(in_ch, base_ch)
        self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch*2, 16, 16)
        self.down3 = GinkaGCNFusedEncoder(base_ch*2, base_ch*4, 8, 8)
        self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
@@ -223,10 +223,6 @@
        )

    def forward(self, x):
        B, D = x.shape  # [B, 1024]
        x = x.view(B, 4, D // 4)  # [B, 4, 256]
        x = self.input(x)  # [B, 4, 512]
        x = x.view(B, 2, 32, 32)  # [B, 2, 32, 32]
        x1 = self.down1(x)  # [B, 64, 32, 32]
        x2 = self.down2(x1)  # [B, 128, 16, 16]
        x3 = self.down3(x2)  # [B, 256, 8, 8]
@@ -237,5 +233,6 @@
        x = self.up1(x4, x3)  # [B, 256, 8, 8]
        x = self.up2(x, x2)  # [B, 128, 16, 16]
        x = self.up3(x, x1)  # [B, 64, 32, 32]
        x = self.final(x)  # [B, 32, 32, 32]

        return self.final(x)  # [B, 32, 32, 32]
        return x

134 ginka/train.py
@@ -1,134 +0,0 @@
import os
from datetime import datetime
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .dataset import GinkaDataset
from minamo.model.model import MinamoModel
from shared.args import parse_arguments

BATCH_SIZE = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)

def train():
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments("result/ginka.pth", "ginka-dataset.json", 'ginka-eval.json')

    model = GinkaModel()
    model.to(device)
    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)
    minamo.eval()

    # prepare the datasets
    dataset = GinkaDataset(args.train, device, minamo)
    dataset_val = GinkaDataset(args.validate, device, minamo)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True
    )
    dataloader_val = DataLoader(
        dataset_val,
        batch_size=BATCH_SIZE,
        shuffle=True
    )

    # set up the optimizer and scheduler
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
    criterion = GinkaLoss(minamo)

    if args.resume:
        data = torch.load(args.from_state, map_location=device)
        model.load_state_dict(data["model_state"], strict=False)
        if args.load_optim:
            optimizer.load_state_dict(data["optimizer_state"])
        print("Train from loaded state.")

    else:
        # when training from scratch, zero the Minamo loss weight at first
        criterion.weight[0] = 0.0

    # start training
    for epoch in tqdm(range(args.epochs)):
        model.train()
        total_loss = 0

        # when training from scratch, restore the Minamo loss weight at epoch 10
        if not args.resume and epoch == 10:
            criterion.weight[0] = 0.5

        for batch in dataloader:
            # move the data to the device
            target = batch["target"].to(device)
            target_vision_feat = batch["target_vision_feat"].to(device)
            target_topo_feat = batch["target_topo_feat"].to(device)
            feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
            # forward pass
            optimizer.zero_grad()
            _, output_softmax = model(feat_vec)

            # compute the loss
            losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)

            # backward pass
            losses.backward()
            optimizer.step()
            total_loss += losses.item()

        avg_loss = total_loss / len(dataloader)
        tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer.param_groups[0]['lr']):.6f}")

        # for name, param in model.named_parameters():
        #     if param.grad is not None:
        #         print(f"{name}: grad_mean={param.grad.abs().mean():.3e}, max={param.grad.abs().max():.3e}")

        # learning-rate step
        scheduler.step()

        if (epoch + 1) % 5 == 0:
            loss_val = 0
            model.eval()
            with torch.no_grad():
                for batch in dataloader_val:
                    # move the data to the device
                    target = batch["target"].to(device)
                    target_vision_feat = batch["target_vision_feat"].to(device)
                    target_topo_feat = batch["target_topo_feat"].to(device)
                    feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)

                    # forward pass
                    output, output_softmax = model(feat_vec)
                    print(torch.argmax(output, dim=1)[0])

                    # compute the loss
                    losses = criterion(output_softmax, target, target_vision_feat, target_topo_feat)
                    loss_val += losses.item()

            avg_val_loss = loss_val / len(dataloader_val)
            tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
            torch.save({
                "model_state": model.state_dict(),
                # "optimizer_state": optimizer.state_dict(),
            }, f"result/ginka_checkpoint/{epoch + 1}.pth")

    print("Train ended.")

    torch.save({
        "model_state": model.state_dict(),
        # "optimizer_state": optimizer.state_dict(),
    }, f"result/ginka.pth")

if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
@@ -1,410 +0,0 @@
import argparse
import socket
import struct
import os
from datetime import datetime
import torch
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import cv2
import numpy as np
from .model.model import GinkaModel
from .model.loss import GinkaLoss, WGANGinkaLoss
from .dataset import GinkaDataset, MinamoGANDataset
from minamo.model.model import MinamoModel
from minamo.model.loss import MinamoLoss
from shared.image import matrix_to_image_cv

BATCH_SIZE = 32
EPOCHS_GINKA = 5
EPOCHS_MINAMO = 2
SOCKET_PATH = "./tmp/ginka_uds"
LOSS_PATH = "result/gan/a-loss.txt"
REPLAY_PATH = "datasets/replay.bin"
VISION_ALPHA = 0

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
os.makedirs("result/gan", exist_ok=True)
os.makedirs("tmp", exist_ok=True)

with open(LOSS_PATH, 'a', encoding='utf-8') as f:
    f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")

if not os.path.exists(REPLAY_PATH):
    with open(REPLAY_PATH, 'wb') as f:
        f.write(b'\x00\x00\x00\x00')

def parse_arguments():
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=bool, default=False)
    parser.add_argument("--from_state", type=str, default="result/ginka.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default='ginka-eval.json')
    parser.add_argument("--from_cycle", type=int, default=0)
    parser.add_argument("--to_cycle", type=int, default=100)
    args = parser.parse_args()
    return args

def parse_ginka_batch(batch):
    target = batch["target"].to(device)
    target_vision_feat = batch["target_vision_feat"].to(device).squeeze(1)
    target_topo_feat = batch["target_topo_feat"].to(device).squeeze(1)
    feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=1).to(device)

    return target, target_vision_feat, target_topo_feat, feat_vec

def parse_minamo_batch(batch):
    map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
    map1 = map1.to(device)  # convert to [B, C, H, W]
    map2 = map2.to(device)
    topo_simi = topo_simi.to(device)
    vision_simi = vision_simi.to(device)
    graph1 = graph1.to(device)
    graph2 = graph2.to(device)
    return map1, map2, vision_simi, topo_simi, graph1, graph2

def send_all(sock, data):
    total_sent = 0
    while total_sent < len(data):
        sent = sock.send(data[total_sent:])
        if sent == 0:
            raise RuntimeError("Socket connection broken")
        total_sent += sent

def recv_all(sock: socket.socket, length: int):
    """Receive in a loop until the requested number of bytes has arrived."""
    data = bytes()
    while len(data) < length:
        packet = sock.recv(length - len(data))  # request only the remaining bytes
        if not packet:
            raise ConnectionError("Connection interrupted")
        data += packet
    return data

def parse_minamo_data(sock: socket.socket, maps: np.ndarray):
    # Node-side wire protocol (units: bytes):
    # 2 - Tensor count; 2 - Review count. Review is right behind train data;
    # 1*tc - Compare count for every map tensor delivered.
    # 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
    # N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
    _, _, H, W = maps.shape
    tc_buf = sock.recv(2)
    rc_buf = sock.recv(2)
    tc = struct.unpack('>h', tc_buf)[0]
    rc = struct.unpack('>h', rc_buf)[0]
    count_buf = recv_all(sock, 1 * tc)
    count: list = struct.unpack(f">{tc}b", count_buf)
    N = sum(count)
    sim_buf = recv_all(sock, 2 * 4 * (N + rc))
    com_buf = recv_all(sock, N * 1 * H * W)
    review_buf = recv_all(sock, rc * 2 * H * W) if rc > 0 else bytes()

    sim = struct.unpack(f">{(N + rc) * 2}f", sim_buf)
    com = struct.unpack(f">{N * 1 * H * W}b", com_buf)
    review = struct.unpack(f">{rc * 2 * H * W}b", review_buf) if rc > 0 else list()

    res = list()
    flatten_idx = 0
    # read this round's generator data
    for idx in range(tc):
        com_count = count[idx]
        for i in range(com_count):
            com_start = flatten_idx * H * W
            com_end = (flatten_idx + 1) * H * W
            vis_sim = sim[flatten_idx * 2]
            topo_sim = sim[flatten_idx * 2 + 1]
            com_data = com[com_start:com_end]
            flatten_idx += 1
            com_map = np.array(com_data, dtype=np.int8).reshape(H, W)
            # map1, map2, vision_similarity, topo_similarity, is_review
            res.append((maps[idx], com_map, vis_sim, topo_sim, False))

    return res
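The byte layout can be exercised directly with struct; a tiny round-trip of the two headers used here (illustration only, not part of the commit):

import struct

# Node-side header: int16 tensor count, int16 review count (big-endian).
node_header = struct.pack('>h', 3) + struct.pack('>h', 1)
tc, rc = struct.unpack('>hh', node_header)
assert (tc, rc) == (3, 1)

# Python-side header: int16 count, int8 map height, int8 map width.
py_header = struct.pack('>h', 32) + struct.pack('>b', 13) + struct.pack('>b', 13)
assert struct.unpack('>hbb', py_header) == (32, 13, 13)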

def train():
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    ginka = GinkaModel()
    ginka.to(device)
    minamo = MinamoModel(32)
    minamo.load_state_dict(torch.load("result/minamo.pth", map_location=device)["model_state"])
    minamo.to(device)
    minamo.eval()

    # prepare the datasets
    ginka_dataset = GinkaDataset(args.train, device, minamo)
    ginka_dataset_val = GinkaDataset(args.validate, device, minamo)
    minamo_dataset = MinamoGANDataset("datasets/minamo-dataset-1.json")
    minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
    ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
    ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
    minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE // 2, shuffle=True)
    minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True)

    # set up the optimizers and schedulers
    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
    criterion_ginka = GinkaLoss(minamo)

    optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
    scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=EPOCHS_MINAMO, T_mult=2, eta_min=1e-6)
    criterion_minamo = MinamoLoss()

    criterion = WGANGinkaLoss()

    # used for rendering images
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    # communicate with the Node side
    if os.path.exists(SOCKET_PATH):
        os.remove(SOCKET_PATH)
    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    server.bind(SOCKET_PATH)
    server.listen(1)

    if args.resume:
        data = torch.load(args.from_state, map_location=device)
        ginka.load_state_dict(data["model_state"], strict=False)
        print("Train from loaded state.")

    print("Waiting for client connection...")
    conn, _ = server.accept()
    print("Client connected.")

    for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
        # -------------------- Train the generator
        for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model", leave=False):
            ginka.train()
            minamo.eval()
            total_loss = 0

            for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
                # move the data to the device
                target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
                # forward pass
                optimizer_ginka.zero_grad()
                _, output_softmax = ginka(feat_vec)
                # compute the loss
                losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
                # backward pass
                losses.backward()
                optimizer_ginka.step()
                total_loss += losses.item()

            avg_loss = total_loss / len(ginka_dataloader)
            tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {avg_loss:.6f} | lr: {(optimizer_ginka.param_groups[0]['lr']):.6f}")

            # learning-rate step
            scheduler_ginka.step(epoch + 1)

            if (epoch + 1) % 5 == 0:
                loss_val = 0
                ginka.eval()
                idx = 0
                with torch.no_grad():
                    for batch in tqdm(ginka_dataloader_val, leave=False, desc="Validating Ginka Model"):
                        target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
                        output, output_softmax = ginka(feat_vec)
                        losses = criterion_ginka(output_softmax, target, target_vision_feat, target_topo_feat)
                        loss_val += losses.item()
                        if epoch + 1 == EPOCHS_GINKA:
                            # render images alongside the final validation pass
                            map_matrix = torch.argmax(output, dim=1).cpu().numpy()
                            for matrix in map_matrix:
                                image = matrix_to_image_cv(matrix, tile_dict)
                                cv2.imwrite(f"result/ginka_img/{idx}.png", image)
                                idx += 1

                avg_val_loss = loss_val / len(ginka_dataloader_val)
                tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
                torch.save({
                    "model_state": ginka.state_dict()
                }, f"result/ginka_checkpoint/{epoch + 1}.pth")

        # use the training set to generate Minamo training data; more accurate
        gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
        prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
        with torch.no_grad():
            for batch in ginka_dataloader:
                target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
                output, output_softmax = ginka(feat_vec)
                prob = output_softmax.cpu().numpy()
                prob_list = np.concatenate((prob_list, prob), axis=0)
                map_matrix = torch.argmax(output, dim=1).cpu().numpy()
                gen_list = np.concatenate((gen_list, map_matrix), axis=0)

        tqdm.write(f"Cycle {cycle} Ginka train ended.")
        torch.save({
            "model_state": ginka.state_dict()
        }, f"result/gan/ginka-{cycle}.pth")
        torch.save({
            "model_state": ginka.state_dict()
        }, f"result/ginka.pth")

        # -------------------- Generate training data for Minamo

        # Python-side wire protocol (units: bytes):
        # 2 - Tensor count; 1 - Map height; 1 - Map Width; N*1*H*W - Map tensor, int8 type.
        N, H, W = gen_list.shape
        gen_bytes = gen_list.astype(np.int8).tobytes()
        buf = bytearray()
        buf.extend(struct.pack('>h', N))  # Tensor count
        buf.extend(struct.pack('>b', H))  # Map height
        buf.extend(struct.pack('>b', W))  # Map width
        buf.extend(gen_bytes)             # Map tensor
        conn.sendall(buf)
        data = parse_minamo_data(conn, prob_list)

        vis_sim = 0
        topo_sim = 0
        for _, _, vis, topo, _ in data:
            vis_sim += vis
            topo_sim += topo

        vis_sim /= len(data)
        topo_sim /= len(data)

        with open(LOSS_PATH, 'a', encoding='utf-8') as f:
            f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}')

        # experience replay
        with open(REPLAY_PATH, 'r+b') as f:
            # read the total count from the file header
            f.seek(0)
            count = struct.unpack('>i', f.read(4))[0]  # unpack the integer
            if count > 0:
                replay = np.random.choice(count, size=min(count, len(data) // 4), replace=False)

                replay_data = np.empty((len(replay), 32, 13, 13))
                for i, n in enumerate(replay):
                    f.seek(n * 32 * 13 * 13 + 4)
                    arr = np.frombuffer(f.read(32 * 13 * 13 * 4), dtype=np.float32).reshape(32, 13, 13)
                    replay_data[i] = arr

                map_data: np.ndarray = replay_data.argmax(axis=1)
                buf = bytearray()
                buf.extend(struct.pack('>h', len(replay)))         # Tensor count
                buf.extend(struct.pack('>b', H))                   # Map height
                buf.extend(struct.pack('>b', W))                   # Map width
                buf.extend(map_data.astype(np.int8).tobytes())     # Map tensor
                conn.sendall(buf)
                data.extend(parse_minamo_data(conn, replay_data))

            # append the new entries at the end of the file
            to_write = np.random.choice(N, size=min(N, 100), replace=False)
            write_data = bytearray()
            for n in to_write:
                write_data.extend(prob_list[n].tobytes())

            f.seek(0, 2)  # seek to the end of the file
            f.write(write_data)

            f.seek(0)  # seek to the start of the file
            f.write(struct.pack('>i', count + len(to_write)))
            f.flush()  # make sure the data is flushed to disk

        minamo_dataset.set_data(data)

        # -------------------- Train the critic
        for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
            ginka.eval()
            minamo.train()
            total_loss = 0

            for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
                map1, map2, vis_sim, topo_sim, graph1, graph2 = parse_minamo_batch(batch)
                batch_size = map1.shape[0]

                if batch_size == 1:
                    continue

                # forward pass
                optimizer_minamo.zero_grad()
                vis_feat_real, topo_feat_real = minamo(map1, graph1)
                vis_feat_ref, topo_feat_ref = minamo(map2, graph2)

                # generate fake data
                with torch.no_grad():
                    fake_feat = torch.randn((batch_size, 1024), device=device)
                    fake_data = ginka(fake_feat)

                # create interpolated samples
                alpha = torch.rand((batch_size, 1, 1, 1), device=device)
                interpolates = (alpha * map2 + (1 - alpha) * fake_data).requires_grad_(True)

                vis_feat_fake, topo_feat_fake = minamo(fake_data)
                vis_feat_interp, topo_feat_interp = minamo(interpolates)

                vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
                topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)
                vis_pred_fake = F.cosine_similarity(vis_feat_fake, vis_feat_ref, dim=1).unsqueeze(-1)
                topo_pred_fake = F.cosine_similarity(topo_feat_fake, topo_feat_ref, dim=1).unsqueeze(-1)
                vis_pred_interp = F.cosine_similarity(vis_feat_interp, vis_feat_ref, dim=1).unsqueeze(-1)
                topo_pred_interp = F.cosine_similarity(topo_feat_interp, topo_feat_ref, dim=1).unsqueeze(-1)

                # compute the similarity scores
                score_real = F.l1_loss(vis_pred_real, vis_sim) * VISION_ALPHA + F.l1_loss(topo_pred_real, topo_sim) * (1 - VISION_ALPHA)
                score_fake = vis_pred_fake * VISION_ALPHA + topo_pred_fake * (1 - VISION_ALPHA)
                score_interp = vis_pred_interp * VISION_ALPHA + topo_pred_interp * (1 - VISION_ALPHA)

                # compute the loss
                loss = criterion.discriminator_loss(score_real, score_fake, score_interp)

                # backward pass
                loss.backward()
                optimizer_minamo.step()
                total_loss += loss.item()

            ave_loss = total_loss / len(minamo_dataloader)
            tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | loss: {ave_loss:.6f} | lr: {(optimizer_minamo.param_groups[0]['lr']):.6f}")

            scheduler_minamo.step(epoch + 1)

            # run the validation set once every ten epochs
            if epoch + 1 == EPOCHS_MINAMO:
                minamo.eval()
                val_loss = 0
                with torch.no_grad():
                    for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
                        map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)

                        vis_feat_real, topo_feat_real = minamo(map1_val, graph1)
                        vis_feat_ref, topo_feat_ref = minamo(map2_val, graph2)

                        vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
                        topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)

                        # compute the loss
                        loss_val = criterion_minamo(vis_pred_real, topo_pred_real, vision_simi_val, topo_simi_val)
                        val_loss += loss_val.item()

                avg_val_loss = val_loss / len(minamo_dataloader_val)
                tqdm.write(f"[INFO {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Validation::loss: {avg_val_loss:.6f}")
                torch.save({
                    "model_state": minamo.state_dict()
                }, f"result/minamo_checkpoint/{epoch + 1}.pth")

        tqdm.write(f"Cycle {cycle} Minamo train ended.")
        torch.save({
            "model_state": minamo.state_dict()
        }, f"result/gan/minamo-{cycle}.pth")
        torch.save({
            "model_state": minamo.state_dict()
        }, f"result/minamo.pth")
        with open(LOSS_PATH, 'a', encoding='utf-8') as f:
            f.write(f' | Minamo: {avg_val_loss:.12f}\n')

    print("Train ended.")

if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
@@ -16,7 +16,7 @@ from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT

BATCH_SIZE = 32
BATCH_SIZE = 16

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@@ -30,39 +30,56 @@ def parse_arguments():
    parser.add_argument("--state_ginka", type=str, default="result/wgan/ginka-100.pth")
    parser.add_argument("--state_minamo", type=str, default="result/wgan/minamo-100.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--validate", type=str, default="ginka-eval.json")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--checkpoint", type=int, default=5)
    parser.add_argument("--load_optim", type=bool, default=True)
    args = parser.parse_args()
    return args

def clip_weights(model, clip_value=0.01):
    for param in model.parameters():
        param.data = torch.clamp(param.data, -clip_value, clip_value)
def gen_curriculum(gen, masked1, masked2, masked3, detach=False) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    fake1: torch.Tensor = gen(masked1, 1)
    fake2: torch.Tensor = gen(masked2, 2)
    fake3: torch.Tensor = gen(masked3, 3)
    if detach:
        return fake1.detach(), fake2.detach(), fake3.detach()
    else:
        return fake1, fake2, fake3

def gen_total(gen, input, detach=False) -> torch.Tensor:
    fake1 = gen(input, 1)
    fake2 = gen(fake1, 2)
    fake3 = gen(fake2, 3)
    if detach:
        return fake3.detach()
    else:
        return fake3
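gen_total chains the three stage passes, feeding each stage's output map into the next; a usage sketch under the same assumptions (not part of the commit):

import torch

ginka = GinkaModel()
blank = torch.zeros(1, 32, 13, 13)
blank[:, 0] = 1.0                             # start from an all-"empty floor" map
final = gen_total(ginka, blank, detach=True)  # stage 1 -> 2 -> 3
print(final.shape)                            # torch.Size([1, 32, 13, 13])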

def train():
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    args = parse_arguments()

    # c_steps = 1 if args.resume else 5
    # g_steps = 5 if args.resume else 1
    c_steps = 5
    g_steps = 1
    # Stage 1 is the curriculum-learning stage; stage 2 comes after the curriculum and gradually
    # transitions to joint learning; stage 3 is the post-curriculum joint masked-learning stage;
    # stage 4 is the final joint-learning stage with random inputs.
    train_stage = 1
    mask_ratio = 0.1  # masked-region size; grows by 0.1 at a time, and training enters stage 2 once it reaches 0.9
    random_ratio = 0

    ginka = GinkaModel()
    minamo = MinamoScoreModule()
    minamo_sim = MinamoSimilarityModel()
    ginka.to(device)
    minamo.to(device)
    minamo_sim.to(device)

    dataset = GinkaWGANDataset(args.train, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    dataset_val = GinkaWGANDataset(args.validate, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=True)

    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
    optimizer_minamo = optim.Adam(minamo.parameters(), lr=1e-5, betas=(0.0, 0.9))
    optimizer_minamo_sim = optim.Adam(minamo_sim.parameters(), lr=1e-4)

    # scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=args.epochs)
    # scheduler_minamo = optim.lr_scheduler.CosineAnnealingLR(optimizer_minamo, T_max=args.epochs)
@@ -82,87 +99,134 @@ def train():
        ginka.load_state_dict(data_ginka["model_state"], strict=False)
        minamo.load_state_dict(data_minamo["model_state"], strict=False)

        if data_ginka["c_steps"] is not None and data_ginka["g_steps"] is not None:
        if data_ginka.get("c_steps") is not None and data_ginka.get("g_steps") is not None:
            c_steps = data_ginka["c_steps"]
            g_steps = data_ginka["g_steps"]

        if data_ginka.get("mask_ratio") is not None:
            mask_ratio = data_ginka["mask_ratio"]

        if data_ginka.get("random_ratio") is not None:
            random_ratio = data_ginka["random_ratio"]

        if data_ginka.get("stage") is not None:
            train_stage = data_ginka["stage"]

        if args.load_optim:
            if data_ginka["optim_state"] is not None:
            if data_ginka.get("optim_state") is not None:
                optimizer_ginka.load_state_dict(data_ginka["optim_state"])
            if data_minamo["optim_state"] is not None:
            if data_minamo.get("optim_state") is not None:
                optimizer_minamo.load_state_dict(data_minamo["optim_state"])
            if data_minamo["optim_state_sim"] is not None:
                optimizer_minamo_sim.load_state_dict(data_minamo["optim_state_sim"])

        dataset.train_stage = train_stage
        dataset.mask_ratio1 = mask_ratio
        dataset.mask_ratio2 = mask_ratio
        dataset.mask_ratio3 = mask_ratio
        dataset.random_ratio = random_ratio

        dataset_val.train_stage = train_stage
        dataset_val.mask_ratio1 = mask_ratio
        dataset_val.mask_ratio2 = mask_ratio
        dataset_val.mask_ratio3 = mask_ratio
        dataset_val.random_ratio = random_ratio

        print("Train from loaded state.")

    low_loss_epochs = 0

    for epoch in tqdm(range(args.epochs), desc="WGAN Training", disable=disable_tqdm):
        loss_total_minamo = torch.Tensor([0]).to(device)
        loss_total_minamo_sim = torch.Tensor([0]).to(device)
        loss_total_ginka = torch.Tensor([0]).to(device)
        dis_total = torch.Tensor([0]).to(device)
        loss_ce_total = torch.Tensor([0]).to(device)

        for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            batch_size = real_data.size(0)
            real_data = real_data.to(device)
            real_graph = batch_convert_soft_map_to_graph(real_data)
        for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
            real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]

            # ---------- Train the critic
            for _ in range(c_steps):
                # generate fake samples
                optimizer_minamo.zero_grad()
                z = torch.rand(batch_size, 1024, device=device)
                fake_data = ginka(z)
                fake_data = fake_data.detach()

                # compute critic outputs
                # backward pass
                dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
                loss_d.backward()
                # torch.nn.utils.clip_grad_norm_(minamo.parameters(), max_norm=2.0)
                # total_norm = torch.linalg.vector_norm(torch.stack([torch.linalg.vector_norm(p.grad) for p in minamo.topo_model.parameters()]), 2)
                # print("Critic gradient norm:", total_norm.item())
                # print("Critic input range:", fake_data.min().item(), fake_data.max().item(), real_data.min().item(), real_data.max().item())
                # print("Critic output range:", d_real.min().item(), d_real.max().item())
                optimizer_ginka.zero_grad()
                if train_stage == 1 or train_stage == 2:
                    fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)

                    loss_d1, dis1 = criterion.discriminator_loss(minamo, 1, real1, fake1)
                    loss_d2, dis2 = criterion.discriminator_loss(minamo, 2, real2, fake2)
                    loss_d3, dis3 = criterion.discriminator_loss(minamo, 3, real3, fake3)

                    dis_avg = (dis1 + dis2 + dis3) / 3.0
                    loss_d_avg = (loss_d1 + loss_d2 + loss_d3) / 3.0

                    # backward pass
                    loss_d_avg.backward()
                elif train_stage == 3:
                    pass

                optimizer_minamo.step()

                loss_total_minamo += loss_d.detach()
                dis_total += dis.detach()
                loss_total_minamo += loss_d_avg.detach()
                dis_total += dis_avg.detach()

            # ---------- Train the generator

            for _ in range(g_steps):
                optimizer_minamo.zero_grad()
                optimizer_ginka.zero_grad()
                # optimizer_minamo_sim.zero_grad()

                z1 = torch.randn(batch_size, 1024, device=device)
                z2 = torch.randn(batch_size, 1024, device=device)
                fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)

                # train the auxiliary critic first
                # loss_c_assist = criterion.discriminator_loss_assist2(minamo_sim, real_data, fake_softmax1, fake_softmax2)
                # loss_c_assist.backward(retain_graph=True)
                # optimizer_minamo_sim.step()

                loss_g = criterion.generator_loss(minamo, minamo_sim, fake_softmax1, fake_softmax2)
                loss_g.backward()
                optimizer_ginka.step()

                loss_total_ginka += loss_g
                # loss_total_minamo_sim += loss_c_assist.detach()
                # tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")
                if train_stage == 1 or train_stage == 2:
                    fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, False)

                    loss_g1, _, loss_ce_g1, _ = criterion.generator_loss(minamo, 1, mask_ratio, real1, fake1, masked1)
                    loss_g2, _, loss_ce_g2, _ = criterion.generator_loss(minamo, 2, mask_ratio, real2, fake2, masked2)
                    loss_g3, _, loss_ce_g3, _ = criterion.generator_loss(minamo, 3, mask_ratio, real3, fake3, masked3)

                    loss_g = (loss_g1 + loss_g2 + loss_g3) / 3.0
                    loss_ce = max(loss_ce_g1, loss_ce_g2, loss_ce_g3)

                    loss_g.backward()
                    optimizer_ginka.step()
                    loss_total_ginka += loss_g.detach()
                    loss_ce_total += loss_ce.detach()

                elif train_stage == 3:
                    pass

        avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
        avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
        avg_loss_minamo_sim = loss_total_minamo_sim.item() / len(dataloader) / g_steps
        avg_loss_ce = loss_ce_total.item() / len(dataloader) / g_steps
        avg_dis = dis_total.item() / len(dataloader) / c_steps
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +\
            f"Epoch: {epoch + 1} | W Loss: {avg_dis:.8f} | " +\
            f"G Loss: {avg_loss_ginka:.8f} | D Loss: {avg_loss_minamo:.8f} | " +\
            f"lr G: {(optimizer_ginka.param_groups[0]['lr']):.8f}"
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " +
            f"Epoch: {epoch + 1} | W: {avg_dis:.8f} | " +
            f"G: {avg_loss_ginka:.8f} | D: {avg_loss_minamo:.8f} | " +
            f"CE: {avg_loss_ce:.8f} | Mask: {mask_ratio:.2f}"
        )

        if avg_loss_ce < 0.5:
            low_loss_epochs += 1
        else:
            low_loss_epochs = 0

        if low_loss_epochs >= 5 and train_stage == 2:
            random_ratio += 0.1
            random_ratio = min(random_ratio, 0.5)
            low_loss_epochs = 0

        if low_loss_epochs >= 5 and train_stage == 1:
            if mask_ratio >= 0.9:
                train_stage = 2

            mask_ratio += 0.1
            mask_ratio = min(mask_ratio, 0.9)
            low_loss_epochs = 0

            dataset.train_stage = 2
            dataset_val.train_stage = 2
            dataset.random_ratio = random_ratio
            dataset_val.random_ratio = random_ratio
            dataset.mask_ratio1 = dataset.mask_ratio2 = dataset.mask_ratio3 = mask_ratio
            dataset_val.mask_ratio1 = dataset_val.mask_ratio2 = dataset_val.mask_ratio3 = mask_ratio
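The schedule above advances mask_ratio by 0.1 whenever the CE loss stays under 0.5 for five epochs, then flips to stage 2 and grows random_ratio the same way; a dry-run of that state machine (sketch only, assuming CE stays low at every step):

mask_ratio, random_ratio, stage = 0.1, 0.0, 1
for step in range(12):
    if stage == 1:
        if mask_ratio >= 0.9:
            stage = 2
        mask_ratio = min(mask_ratio + 0.1, 0.9)
    elif stage == 2:
        random_ratio = min(random_ratio + 0.1, 0.5)
    print(step, stage, round(mask_ratio, 1), round(random_ratio, 1))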

        # scheduler_ginka.step()
        # scheduler_minamo.step()

@@ -172,38 +236,44 @@ def train():
            g_steps = 1

        if avg_loss_ginka > 0 or avg_loss_minamo > 0:
            c_steps = min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15)
            c_steps = int(max(min(5 + (avg_loss_ginka + avg_loss_minamo) * 5, 15), 1))
        else:
            c_steps = 5

        # every few epochs, render images and save a checkpoint
        if (epoch + 1) % args.checkpoint == 0:
            # output 20 images: 4 per batch, 5 batches in total
            idx = 0
            with torch.no_grad():
                for _ in range(5):
                    z = torch.randn(4, 1024, device=device)
                    output = ginka(z)

                    map_matrix = torch.argmax(output, dim=1).cpu().numpy()
                    for matrix in map_matrix:
                        image = matrix_to_image_cv(matrix, tile_dict)
                        cv2.imwrite(f"result/ginka_img/{idx}.png", image)
                        idx += 1

            # save checkpoints
            torch.save({
                "model_state": ginka.state_dict(),
                "optim_state": optimizer_ginka.state_dict(),
                "c_steps": c_steps,
                "g_steps": g_steps
                "g_steps": g_steps,
                "stage": train_stage,
                "mask_ratio": mask_ratio,
                "random_ratio": random_ratio,
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict(),
                "model_state_sim": minamo_sim.state_dict(),
                "optim_state": optimizer_minamo.state_dict(),
                "optim_state_sim": optimizer_minamo_sim.state_dict()
                "optim_state": optimizer_minamo.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")

            idx = 0
            with torch.no_grad():
                for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
                    real1, masked1, real2, masked2, real3, masked3 = [item.to(device) for item in batch]
                    if train_stage == 1:
                        fake1, fake2, fake3 = gen_curriculum(ginka, masked1, masked2, masked3, True)
                        fake1 = torch.argmax(fake1, dim=1).cpu().numpy()
                        fake2 = torch.argmax(fake2, dim=1).cpu().numpy()
                        fake3 = torch.argmax(fake3, dim=1).cpu().numpy()

                        for i in range(fake1.shape[0]):
                            for key, one in enumerate([fake1, fake2, fake3]):
                                map_matrix = one[i]
                                image = matrix_to_image_cv(map_matrix, tile_dict)
                                cv2.imwrite(f"result/ginka_img/{idx}_{key}.png", image)

                            idx += 1

    print("Train ended.")
    torch.save({
@@ -211,7 +281,6 @@ def train():
    }, f"result/ginka.pth")
    torch.save({
        "model_state": minamo.state_dict(),
        "model_state_sim": minamo_sim.state_dict(),
    }, f"result/minamo.pth")

if __name__ == "__main__":

@@ -20,23 +20,41 @@ class MinamoModel(nn.Module):

        return vision_feat, topo_feat

class MinamoScoreHead(nn.Module):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.vision_fc = nn.Sequential(
            spectral_norm(nn.Linear(in_dim, out_dim)),
        )
        self.topo_fc = nn.Sequential(
            spectral_norm(nn.Linear(in_dim, out_dim))
        )

    def forward(self, vis_feat, topo_feat):
        vis_score = self.vision_fc(vis_feat)
        topo_score = self.topo_fc(topo_feat)
        return vis_score, topo_score

class MinamoScoreModule(nn.Module):
    def __init__(self, tile_types=32):
        super().__init__()
        self.topo_model = MinamoTopoModel(tile_types)
        self.vision_model = MinamoVisionModel(tile_types)
        # output layers
        self.topo_fc = nn.Sequential(
            spectral_norm(nn.Linear(512, 1)),
        )
        self.vision_fc = nn.Sequential(
            spectral_norm(nn.Linear(512, 1)),
        )
        self.head1 = MinamoScoreHead(512, 1)
        self.head2 = MinamoScoreHead(512, 1)
        self.head3 = MinamoScoreHead(512, 1)

    def forward(self, map, graph):
        topo_feat = self.topo_model(graph)
        topo_score = self.topo_fc(topo_feat)
    def forward(self, map, graph, stage):
        vision_feat = self.vision_model(map)
        vision_score = self.vision_fc(vision_feat)
        topo_feat = self.topo_model(graph)
        if stage == 1:
            vision_score, topo_score = self.head1(vision_feat, topo_feat)
        elif stage == 2:
            vision_score, topo_score = self.head2(vision_feat, topo_feat)
        elif stage == 3:
            vision_score, topo_score = self.head3(vision_feat, topo_feat)
        else:
            raise RuntimeError("Unknown critic stage.")
        score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
        return score, vision_score, topo_score