refactor: 改为 Wasserstein GAN

This commit is contained in:
unanmed 2025-04-06 18:44:18 +08:00
parent d7209a68a2
commit 29cfb4d029
14 changed files with 643 additions and 183 deletions

View File

@ -4,7 +4,7 @@ import { MinamoTrainData } from './types';
import { generateTrainData } from './process/minamo';
const SOCKET_FILE = '../tmp/ginka_uds';
const [refer] = process.argv.slice(2);
const [refer, replayPath = '../datasets/replay.bin'] = process.argv.slice(2);
let id = 0;
@ -43,7 +43,7 @@ function generateGANData(
const size2: [number, number] = [map[0].length, map.length];
if (size1[0] !== size2[0] || size1[1] !== size2[1]) return [];
return generateTrainData(v, id2, floor.map, map, size1);
return generateTrainData(v, id2, floor.map, map, size1, false, false, false);
});
return data.flat();
}
@ -104,7 +104,7 @@ class DataReceiver {
console.log(`UDS IPC connected successfully.`);
});
client.on('data', buffer => {
client.on('data', async buffer => {
const data = DataReceiver.check(buffer);
if (!data) return;
@ -112,20 +112,19 @@ class DataReceiver {
const simData = map.map(v => generateGANData(keys, referTower, v));
const rc = 0;
const compareData = simData.flat();
const reviewData: MinamoTrainData[] = [];
// 数据通讯 node 输出协议,单位字节:
// 2 - Tensor count; 2 - Review count. Review is right behind train data;
// 2 - Tensor count; 2 - Replay count. Replay is right behind train data;
// 1*tc - Compare count for every map tensor delivered.
// 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
// N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
// N*1*H*W - Compare map for every map tensor. rc*2*H*W - Replay map tensor.
const toSend = Buffer.alloc(
2 + // Tensor count
2 + // Review count
2 + // Replay count
1 * count + // Compare count
2 * 4 * (compareData.length + rc) + // Similarity data
compareData.length * 1 * h * w + // Compare map
rc * 2 * h * w, // Review map
rc * 2 * h * w, // Replay map
0
);
console.log(
@ -141,7 +140,7 @@ class DataReceiver {
let offset = 0;
toSend.writeInt16BE(count); // Tensor count
toSend.writeInt16BE(0, 2); // Review count
toSend.writeInt16BE(0, 2); // Replay count
offset += 2 + 2;
// Compare count
toSend.set(
@ -164,13 +163,6 @@ class DataReceiver {
offset // Set from Compare map
);
offset += compareData.length * 1 * h * w;
if (reviewData.length > 0) {
// Review map
toSend.set(
new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)),
offset // Set from last chunk
);
}
client.write(toSend);
});

View File

@ -192,7 +192,10 @@ export function generateTrainData(
id2: string,
map1: number[][],
map2: number[][],
size: [number, number]
size: [number, number],
hasSelf: boolean = true,
hasTransform: boolean = true,
hasSimilar: boolean = true
) {
const topoSimilarity = compareMap(id1, id2, map1, map2);
const visionSimilarity = calculateVisualSimilarity(map1, map2);
@ -205,39 +208,47 @@ export function generateTrainData(
};
const data: MinamoTrainData[] = [];
data.push(train);
// 自身与自身对比的训练集,保证模型对相同地图输出 1
const self1 = `${id1}:${id1}`;
const self2 = `${id2}:${id2}`;
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
if (selfTrain.includes(self1)) {
const selfTrain1: MinamoTrainData = {
map1: map1,
map2: map1,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data.push(selfTrain1);
if (hasSelf) {
// 自身与自身对比的训练集,保证模型对相同地图输出 1
const self1 = `${id1}:${id1}`;
const self2 = `${id2}:${id2}`;
const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
if (selfTrain.includes(self1)) {
const selfTrain1: MinamoTrainData = {
map1: map1,
map2: map1,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data.push(selfTrain1);
}
if (selfTrain.includes(self2)) {
const selfTrain2: MinamoTrainData = {
map1: map2,
map2: map2,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data.push(selfTrain2);
}
}
if (selfTrain.includes(self2)) {
const selfTrain2: MinamoTrainData = {
map1: map2,
map2: map2,
topoSimilarity: 1,
visionSimilarity: 1,
size: size
};
data.push(selfTrain2);
if (hasTransform) {
const transform = generateTransformData(
id1,
id2,
map1,
map2,
topoSimilarity
);
data.push(...transform.map(v => v[1]))
}
const transform = generateTransformData(
id1,
id2,
map1,
map2,
topoSimilarity
);
const similar = generateSimilarData(id1, map1);
return [...data, ...transform.map(v => v[1]), ...similar.map(v => v[1])];
if (hasSimilar) {
const similar = generateSimilarData(id1, map1);
data.push(...similar.map(v => v[1]))
}
return data;
}
export function generatePair(

View File

@ -87,8 +87,7 @@ function weisfeilerLehmanIteration(
});
weight *= decay;
});
// 把每个节点的原始标签也加上权重使用最远权重再衰减1次可以认为是资源重复率
weight *= decay;
// 把每个节点的原始标签也加上,权重使用最远权重,可以认为是资源重复率
nodes.forEach(node => {
if (!numMap.has(node.originalLabel)) {
numMap.set(node.originalLabel, weight);

View File

@ -51,6 +51,25 @@ class GinkaDataset(Dataset):
"target": target,
}
class GinkaWGANDataset(Dataset):
    """Dataset of reference maps used as the "real" samples for WGAN training.

    Each item is a label-smoothed one-hot map tensor of shape [32, H, W],
    moved to the configured device.
    """

    def __init__(self, data_path: str, device):
        self.data = load_data(data_path)  # custom data-loading helper
        self.device = device

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
        # Randomised smoothing: the dominant class keeps a weight drawn from
        # [min_main, max_main] and up to `epsilon` of mass is spread over the
        # other classes (see random_smooth_onehot for exact semantics —
        # NOTE(review): confirm against its implementation).
        min_main = random.uniform(0.75, 0.9)
        max_main = random.uniform(0.9, 1)
        epsilon = random.uniform(0, 0.25)
        target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
        return target_smooth
class MinamoGANDataset(Dataset):
def __init__(self, refer_data_path):
self.refer = load_minamo_gan_data(load_data(refer_data_path))
@ -67,6 +86,7 @@ class MinamoGANDataset(Dataset):
return len(self.data)
def __getitem__(self, idx):
# 假定 map2 是参考地图
item = self.data[idx]
map1, map2, vis_sim, topo_sim, review = item

View File

@ -3,8 +3,10 @@ from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from minamo.model.model import MinamoModel
from shared.graph import batch_convert_soft_map_to_graph
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
CLASS_NUM = 32
ILLEGAL_MAX_NUM = 12
@ -260,10 +262,129 @@ class GinkaLoss(nn.Module):
)
losses = [
minamo_loss * self.weight[0],
minamo_loss * self.weight[0] * 4,
class_loss * self.weight[1],
entrance_loss * self.weight[2],
count_loss * self.weight[3]
]
return sum(losses)
# Linear interpolation between real and fake image batches.
def interpolate_data(real_data, fake_data, epsilon):
    """Return ``epsilon * real + (1 - epsilon) * fake`` element-wise.

    ``epsilon`` may be a scalar or any tensor broadcastable to the data,
    e.g. per-sample coefficients of shape [B, 1, 1, 1].
    """
    blended = fake_data + epsilon * (real_data - fake_data)
    return blended
# Interpolate node features between two graphs while keeping the edge
# structure fixed (used by the gradient penalty on the topology critic).
def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
    # Interpolate node features only.
    x_real, x_fake = real_graph.x, fake_graph.x
    x_interp = epsilon * x_real + (1 - epsilon) * x_fake
    # Keep connectivity and edge attributes from the real graph unchanged;
    # mixing edge sets would not yield a meaningful graph.
    edge_index_interp = real_graph.edge_index  # keep edge connectivity
    edge_attr_interp = real_graph.edge_attr  # keep edge features, if any
    return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
def js_divergence(P, Q, epsilon=1e-10):
    """Jensen-Shannon divergence between two per-pixel class distributions.

    Args:
        P, Q: [B, C, H, W] tensors whose channel dimension already sums to 1
            (e.g. softmax outputs).
        epsilon: numerical floor inside the log to avoid log(0).

    Returns:
        Scalar tensor: the JS divergence averaged over batch and pixels.
    """
    # Put the class dimension last so the distribution lives on dim=-1.
    p = P.permute(0, 2, 3, 1)  # [B, H, W, C]
    q = Q.permute(0, 2, 3, 1)
    # Mixture distribution M = (P + Q) / 2.
    mixture = (p + q) * 0.5
    log_m = torch.log(mixture + epsilon)
    # F.kl_div(log_m, target) computes KL(target || M) point-wise.
    kl_p = F.kl_div(log_m, p, reduction='none', log_target=False).sum(dim=-1)  # [B, H, W]
    kl_q = F.kl_div(log_m, q, reduction='none', log_target=False).sum(dim=-1)  # [B, H, W]
    # JS = (KL(P||M) + KL(Q||M)) / 2, averaged globally to a scalar.
    return (0.5 * (kl_p + kl_q)).mean()
class WGANGinkaLoss:
    """WGAN-GP losses for the Ginka generator and the Minamo critic.

    Args:
        lambda_gp: gradient penalty coefficient.
        weight: (adversarial, tile-class constraint, entrance constraint)
            mix for the generator loss.
        diversity_lamda: weight of the JS-divergence diversity bonus between
            two independently generated samples.
    """

    def __init__(self, lambda_gp=10, weight=(0.7, 0.2, 0.1), diversity_lamda=0):
        self.lambda_gp = lambda_gp  # gradient penalty coefficient
        # Copy into a fresh list so instances never share (or mutate) a
        # default object, while callers can still assign self.weight[i].
        self.weight = list(weight)
        self.diversity_lamda = diversity_lamda

    def compute_gradient_penalty(self, critic, real_data, fake_data):
        """1-Lipschitz gradient penalty evaluated at interpolated samples."""
        batch_size = real_data.size(0)
        # BUG FIX: WGAN-GP samples the interpolation coefficient uniformly
        # from [0, 1]; the previous torch.randn drew from N(0, 1), placing
        # many interpolates outside the real-fake segment and biasing the
        # penalty.
        epsilon_data = torch.rand(batch_size, 1, 1, 1, device=real_data.device)
        interp_data = interpolate_data(real_data, fake_data, epsilon_data)
        interp_graph = batch_convert_soft_map_to_graph(interp_data)
        # Enable gradients on both critic inputs (image tensor + node feats).
        interp_data.requires_grad_()
        interp_graph.x.requires_grad_()
        _, d_vis_score, d_topo_score = critic(interp_data, interp_graph)
        # Gradient of the vision score w.r.t. the interpolated image.
        grad_vis = torch.autograd.grad(
            outputs=d_vis_score, inputs=interp_data,
            grad_outputs=torch.ones_like(d_vis_score),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        # Gradient of the topology score w.r.t. the interpolated node features.
        grad_topo = torch.autograd.grad(
            outputs=d_topo_score, inputs=interp_graph.x,
            grad_outputs=torch.ones_like(d_topo_score),
            create_graph=True, retain_graph=True, only_inputs=True
        )[0]
        # Per-sample L2 norms, penalised towards 1 (the Lipschitz target).
        grad_norm_vis = grad_vis.view(batch_size, -1).norm(2, dim=1)
        grad_norm_topo = grad_topo.view(batch_size, -1).norm(2, dim=1)
        gp_loss_vis = ((grad_norm_vis - 1.0) ** 2).mean()
        gp_loss_topo = ((grad_norm_topo - 1.0) ** 2).mean()
        gp_loss = gp_loss_vis * VISION_WEIGHT + gp_loss_topo * TOPO_WEIGHT
        return gp_loss

    def discriminator_loss(
        self, critic, real_data: torch.Tensor,
        real_graph: torch.Tensor, fake_data: torch.Tensor
    ):
        """Critic loss: Wasserstein estimate plus the gradient penalty.

        Returns:
            (wasserstein_term, total_loss) so callers can log the raw
            distance estimate separately from the penalised objective.
        """
        fake_graph = batch_convert_soft_map_to_graph(fake_data)
        real_scores, _, _ = critic(real_data, real_graph)
        fake_scores, _, _ = critic(fake_data, fake_graph)
        # Negative Wasserstein distance estimate (critic maximises the gap).
        d_loss = fake_scores.mean() - real_scores.mean()
        grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data)
        return d_loss, d_loss + self.lambda_gp * grad_loss

    def generator_loss_one(self, critic, fake):
        """Adversarial + map-validity loss for a single generated batch."""
        fake_graph = batch_convert_soft_map_to_graph(fake)
        fake_scores, _, _ = critic(fake, fake_graph)
        minamo_loss = -torch.mean(fake_scores)  # fool the critic
        class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
        entrance_loss = entrance_constraint_loss(fake)
        losses = [
            minamo_loss * self.weight[0],
            class_loss * self.weight[1],
            entrance_loss * self.weight[2]
        ]
        return sum(losses)

    def generator_loss(self, critic, fake1, fake2):
        """Generator loss over two samples, minus a JS diversity bonus."""
        loss1 = self.generator_loss_one(critic, fake1)
        loss2 = self.generator_loss_one(critic, fake2)
        return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * js_divergence(fake1, fake2)

View File

@ -3,7 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .unet import GinkaUNet
from .output import GinkaOutput
from .input import GinkaInput, GinkaFeatureInput
from .input import GinkaInput
def print_memory(tag=""):
print(f"{tag} | 当前显存: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, 最大显存: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@ -13,9 +13,7 @@ class GinkaModel(nn.Module):
"""Ginka Model 模型定义部分
"""
super().__init__()
self.input = GinkaInput(feat_dim, 1, (32, 32))
self.feat_enc = GinkaFeatureInput(feat_dim, 2, base_ch)
self.unet = GinkaUNet(1, base_ch, base_ch, feat_dim)
self.unet = GinkaUNet(base_ch, base_ch, feat_dim)
self.output = GinkaOutput(base_ch, out_ch, (13, 13))
def forward(self, x):
@ -25,12 +23,9 @@ class GinkaModel(nn.Module):
Returns:
logits: 输出logits [BS, num_classes, H, W]
"""
cond = x
feat = self.feat_enc(x)
x = self.input(x)
x = self.unet(x, feat, cond)
x = self.unet(x)
x = self.output(x)
return x, F.softmax(x, dim=1)
return F.softmax(x, dim=1)
# 检查显存占用
if __name__ == "__main__":
@ -48,8 +43,8 @@ if __name__ == "__main__":
print(f"输入形状: feat={feat.shape}")
print(f"输出形状: output={output.shape}, softmax={output_softmax.shape}")
print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
# print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
# print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

View File

@ -1,40 +1,51 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.constant import FEAT_DIM
class GinkaTransformerEncoder(nn.Module):
    """Token-wise transformer encoder with a learned positional embedding.

    The ``in_dim`` / ``hidden_dim`` / ``out_dim`` arguments are totals across
    all tokens and are divided by ``token_size`` to obtain per-token widths,
    so callers reason in terms of the flattened feature size.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
        super().__init__()
        # Convert "total" widths into per-token widths.
        per_token_in = in_dim // token_size
        per_token_hidden = hidden_dim // token_size
        per_token_out = out_dim // token_size
        self.embedding = nn.Sequential(
            nn.Linear(per_token_in, per_token_hidden),
            nn.LayerNorm(per_token_hidden)
        )
        # One learned positional vector per token slot, broadcast over batch.
        self.pos_embedding = nn.Parameter(torch.randn(1, token_size, per_token_hidden))
        encoder_layer = nn.TransformerEncoderLayer(
            per_token_hidden, num_heads, dim_feedforward=ff_dim, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Sequential(
            nn.Linear(per_token_hidden, per_token_out),
            nn.LayerNorm(per_token_out)
        )

    def forward(self, x):
        """Map [B, L, in_dim // token_size] -> [B, L, out_dim // token_size]."""
        hidden = self.embedding(x)
        hidden = hidden + self.pos_embedding
        hidden = self.transformer(hidden)
        return self.fc(hidden)
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
)
def forward(self, x):
return self.conv(x)
class ConditionFusionBlock(nn.Module):
def __init__(self):
super().__init__()
self.alpha = nn.Parameter(torch.tensor(0.5)) # 可学习融合系数
def forward(self, x, cond_feat):
return x + self.alpha * cond_feat # 残差融合
class FusionConvBlock(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.fusion = ConditionFusionBlock()
def forward(self, x, feat):
x = self.conv(x)
x = self.fusion(x, feat)
return x
class GinkaEncoder(nn.Module):
"""编码器(下采样)部分"""
@ -42,12 +53,10 @@ class GinkaEncoder(nn.Module):
super().__init__()
self.conv = ConvBlock(in_ch, out_ch)
self.pool = nn.MaxPool2d(2)
self.fusion = ConditionFusionBlock()
def forward(self, x, feat):
def forward(self, x):
x = self.conv(x)
x = self.pool(x)
x = self.fusion(x, feat)
return x
class GinkaUpSample(nn.Module):
@ -55,8 +64,8 @@ class GinkaUpSample(nn.Module):
super().__init__()
self.conv = nn.Sequential(
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
nn.BatchNorm2d(out_ch),
nn.GELU(),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
)
def forward(self, x):
@ -68,46 +77,57 @@ class GinkaDecoder(nn.Module):
super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.conv = ConvBlock(in_ch, out_ch)
self.fusion = ConditionFusionBlock()
def forward(self, x, skip, feat):
dec = self.upsample(x)
x = torch.cat([dec, skip], dim=1)
def forward(self, x, feat):
x = self.upsample(x)
x = torch.cat([x, feat], dim=1)
x = self.conv(x)
x = self.fusion(x, feat)
return x
class GinkaUNet(nn.Module):
def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
"""Ginka Model UNet 部分
"""
super().__init__()
self.in_conv = FusionConvBlock(in_ch, base_ch)
self.down1 = GinkaEncoder(base_ch, base_ch*2)
self.down2 = GinkaEncoder(base_ch*2, base_ch*4)
self.down3 = GinkaEncoder(base_ch*4, base_ch*8)
self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16)
self.up1 = GinkaDecoder(base_ch*16, base_ch*8)
self.up2 = GinkaDecoder(base_ch*8, base_ch*4)
self.up3 = GinkaDecoder(base_ch*4, base_ch*2)
self.up4 = GinkaDecoder(base_ch*2, base_ch)
self.input = GinkaTransformerEncoder(
in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32, # 自动除以 token_size
token_size=4, ff_dim=feat_dim*2, num_layers=4
)
self.down1 = ConvBlock(2, base_ch)
self.down2 = GinkaEncoder(base_ch, base_ch*2)
self.down3 = GinkaEncoder(base_ch*2, base_ch*4)
self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
self.bottleneck = GinkaTransformerEncoder(
in_dim=base_ch*8*4*4, hidden_dim=base_ch*8*4*4, out_dim=base_ch*8*4*4,
token_size=16, ff_dim=1024, num_layers=4
)
self.up1 = GinkaDecoder(base_ch*8, base_ch*4)
self.up2 = GinkaDecoder(base_ch*4, base_ch*2)
self.up3 = GinkaDecoder(base_ch*2, base_ch)
self.final = nn.Sequential(
nn.Conv2d(base_ch, out_ch, 1),
nn.InstanceNorm2d(out_ch),
nn.ELU(),
)
def forward(self, x, feat, cond):
x1 = self.in_conv(x, feat[0])
x2 = self.down1(x1, feat[1])
x3 = self.down2(x2, feat[2])
x4 = self.down3(x3, feat[3])
x5 = self.bottleneck(x4, feat[4])
def forward(self, x):
B, D = x.shape # [B, 1024]
x = x.view(B, 4, D // 4) # [B, 4, 256]
x = self.input(x) # [B, 4, 512]
x = x.view(B, 2, 32, 32) # [B, 2, 32, 32]
x1 = self.down1(x) # [B, 64, 32, 32]
x2 = self.down2(x1) # [B, 128, 16, 16]
x3 = self.down3(x2) # [B, 256, 8, 8]
x4 = self.down4(x3) # [B, 512, 4, 4]
x4 = x4.view(B, 512, 16).permute(0, 2, 1) # [B, 16, 512]
x4 = self.bottleneck(x4) # [B, 16, 512]
x4 = x4.permute(0, 2, 1).view(B, 512, 4, 4) # [B, 512, 4, 4]
x = self.up1(x5, x4, feat[3])
x = self.up2(x, x3, feat[2])
x = self.up3(x, x2, feat[1])
x = self.up4(x, x1, feat[0])
# 上采样
x = self.up1(x4, x3) # [B, 256, 8, 8]
x = self.up2(x, x2) # [B, 128, 16, 16]
x = self.up3(x, x1) # [B, 64, 32, 32]
return self.final(x)
return self.final(x) # [B, 32, 32, 32]

View File

@ -11,17 +11,19 @@ from tqdm import tqdm
import cv2
import numpy as np
from .model.model import GinkaModel
from .model.loss import GinkaLoss
from .model.loss import GinkaLoss, WGANGinkaLoss
from .dataset import GinkaDataset, MinamoGANDataset
from minamo.model.model import MinamoModel
from minamo.model.loss import MinamoLoss
from shared.image import matrix_to_image_cv
BATCH_SIZE = 32
EPOCHS_GINKA = 30
EPOCHS_MINAMO = 5
EPOCHS_GINKA = 5
EPOCHS_MINAMO = 2
SOCKET_PATH = "./tmp/ginka_uds"
LOSS_PATH = "result/gan/a-loss.txt"
REPLAY_PATH = "datasets/replay.bin"
VISION_ALPHA = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
@ -32,6 +34,10 @@ os.makedirs("tmp", exist_ok=True)
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")
if not os.path.exists(REPLAY_PATH):
with open(REPLAY_PATH, 'wb') as f:
f.write(b'\x00\x00\x00\x00')
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
@ -142,14 +148,16 @@ def train():
minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True)
# 设定优化器与调度器
optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3)
optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
criterion_ginka = GinkaLoss(minamo)
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-4)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=EPOCHS_MINAMO, T_mult=2, eta_min=1e-6)
criterion_minamo = MinamoLoss()
criterion = WGANGinkaLoss()
# 用于生成图片
tile_dict = dict()
for file in os.listdir('tiles'):
@ -168,27 +176,17 @@ def train():
ginka.load_state_dict(data["model_state"], strict=False)
print("Train from loaded state.")
else:
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
criterion_ginka.weight[0] = 0.0
print("Waiting for client connection...")
conn, _ = server.accept()
print("Client connected.")
for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
# -------------------- 训练生成器
gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"):
for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model", leave=False):
ginka.train()
minamo.eval()
total_loss = 0
# 从头开始训练的,在第 10 个 epoch 将 minamo 损失值权重改回来
if not args.resume and epoch == 10:
criterion_ginka.weight[0] = 0.5
for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
# 数据迁移到设备
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
@ -233,6 +231,8 @@ def train():
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
# 使用训练集生成 minamo 训练数据,更准确
gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
with torch.no_grad():
for batch in ginka_dataloader:
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
@ -263,7 +263,7 @@ def train():
buf.extend(gen_bytes) # Map tensor
conn.sendall(buf)
data = parse_minamo_data(conn, prob_list)
minamo_dataset.set_data(data)
vis_sim = 0
topo_sim = 0
for _, _, vis, topo, _ in data:
@ -276,6 +276,44 @@ def train():
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}')
# 经验回放部分
with open(REPLAY_PATH, 'r+b') as f:
# 读取文件开头获取总长度
f.seek(0)
count = struct.unpack('>i', f.read(4))[0] # 取出整数
if count > 0:
replay = np.random.choice(count, size=min(count, len(data) // 4), replace=False)
replay_data = np.empty((len(replay), 32, 13, 13))
for i, n in enumerate(replay):
f.seek(n * 32 * 13 * 13 + 4)
arr = np.frombuffer(f.read(32 * 13 * 13 * 4), dtype=np.float32).reshape(32, 13, 13)
replay_data[i] = arr
map_data: np.ndarray = replay_data.argmax(axis=1)
buf = bytearray()
buf.extend(struct.pack('>h', len(replay))) # Tensor count
buf.extend(struct.pack('>b', H)) # Map height
buf.extend(struct.pack('>b', W)) # Map width
buf.extend(map_data.astype(np.int8).tobytes()) # Map tensor
conn.sendall(buf)
data.extend(parse_minamo_data(conn, replay_data))
# 把新的内容写入文件末尾
to_write = np.random.choice(N, size=min(N, 100), replace=False)
write_data = bytearray()
for n in to_write:
write_data.extend(prob_list[n].tobytes())
f.seek(0, 2) # 定位到文件末尾
f.write(write_data)
f.seek(0) # 定位到文件开头
f.write(struct.pack('>i', count + len(to_write)))
f.flush() # 确保数据被刷新到磁盘
minamo_dataset.set_data(data)
# -------------------- 训练判别器
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
ginka.eval()
@ -283,21 +321,43 @@ def train():
total_loss = 0
for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
map1, map2, vis_sim, topo_sim, graph1, graph2 = parse_minamo_batch(batch)
batch_size = map1.shape[0]
if map1.shape[0] == 1:
if batch_size == 1:
continue
# 前向传播
optimizer_minamo.zero_grad()
vision_feat1, topo_feat1 = minamo(map1, graph1)
vision_feat2, topo_feat2 = minamo(map2, graph2)
vis_feat_real, topo_feat_real = minamo(map1, graph1)
vis_feat_ref, topo_feat_ref = minamo(map2, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
# 生成假数据
with torch.no_grad():
fake_feat = torch.randn((batch_size, 1024), device=device)
fake_data = ginka(fake_feat)
# 创建插值样本
alpha = torch.rand((batch_size, 1, 1, 1), device=device)
interpolates = (alpha * map2 + (1 - alpha) * fake_data).requires_grad_(True)
vis_feat_fake, topo_feat_fake = minamo(fake_data)
vis_feat_interp, topo_feat_interp = minamo(interpolates)
vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)
vis_pred_fake = F.cosine_similarity(vis_feat_fake, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_fake = F.cosine_similarity(topo_feat_fake, topo_feat_ref, dim=1).unsqueeze(-1)
vis_pred_interp = F.cosine_similarity(vis_feat_interp, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_interp = F.cosine_similarity(topo_feat_interp, topo_feat_ref, dim=1).unsqueeze(-1)
# 计算相似度
score_real = F.l1_loss(vis_pred_real, vis_sim) * VISION_ALPHA + F.l1_loss(topo_pred_real, topo_sim) * (1 - VISION_ALPHA)
score_fake = vis_pred_fake * VISION_ALPHA + topo_pred_fake * (1 - VISION_ALPHA)
score_interp = vis_pred_interp * VISION_ALPHA + topo_pred_interp * (1 - VISION_ALPHA)
# 计算损失
loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi)
loss = criterion.discriminator_loss(score_real, score_fake, score_interp)
# 反向传播
loss.backward()
@ -310,21 +370,21 @@ def train():
scheduler_minamo.step(epoch + 1)
# 每十轮推理一次验证集
if (epoch + 1) % 5 == 0:
if epoch + 1 == EPOCHS_MINAMO:
minamo.eval()
val_loss = 0
with torch.no_grad():
for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
vision_feat2, topo_feat2 = minamo(map2_val, graph2)
vis_feat_real, topo_feat_real = minamo(map1_val, graph1)
vis_feat_ref, topo_feat_ref = minamo(map2_val, graph2)
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)
# 计算损失
loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
loss_val = criterion_minamo(vis_pred_real, topo_pred_real, vision_simi_val, topo_simi_val)
val_loss += loss_val.item()
avg_val_loss = val_loss / len(minamo_dataloader_val)

164
ginka/train_wgan.py Normal file
View File

@ -0,0 +1,164 @@
import argparse
import os
from datetime import datetime
import torch
import torch.optim as optim
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from minamo.model.model import MinamoScoreModule
from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
# Training batch size for the WGAN loop.
BATCH_SIZE = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output directories for checkpoints and rendered samples.
os.makedirs("result", exist_ok=True)
os.makedirs("result/wgan", exist_ok=True)
def _str2bool(value: str) -> bool:
    """Parse common true/false spellings for a boolean CLI option."""
    lowered = value.lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def parse_arguments(argv=None):
    """Parse the WGAN training CLI options.

    Args:
        argv: optional argument list (defaults to sys.argv[1:]); exposed so
            tests can drive the parser directly. Backward compatible: calls
            with no arguments behave as before.

    Returns:
        argparse.Namespace with resume / state paths / train set / epochs.
    """
    parser = argparse.ArgumentParser(description="training codes")
    # BUG FIX: type=bool treats ANY non-empty string — including "False" —
    # as True; parse boolean spellings explicitly instead.
    parser.add_argument("--resume", type=_str2bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/ginka.pth")
    parser.add_argument("--state_minamo", type=str, default="result/minamo.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--epochs", type=int, default=100)
    args = parser.parse_args(argv)
    return args
def clip_weights(model, clip_value=0.01):
    """Clamp every parameter of ``model`` into [-clip_value, clip_value].

    Classic weight-clipping Lipschitz enforcement for WGAN critics
    (superseded by gradient penalty, but kept available).
    """
    with torch.no_grad():
        for weight in model.parameters():
            weight.clamp_(min=-clip_value, max=clip_value)
def train():
    """Full WGAN-GP training loop for the Ginka generator / Minamo critic.

    Alternates critic and generator updates per batch, adapts the number of
    generator steps from the running Wasserstein estimate, and periodically
    renders sample maps and saves checkpoints.
    """
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")
    c_steps = 1  # critic updates per batch
    g_steps = 3  # generator updates per batch (adapted below)
    args = parse_arguments()
    ginka = GinkaModel()
    minamo = MinamoScoreModule()
    ginka.to(device)
    minamo.to(device)
    dataset = GinkaWGANDataset(args.train, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    optimizer_ginka = optim.RMSprop(ginka.parameters(), lr=2e-4)
    optimizer_minamo = optim.RMSprop(minamo.parameters(), lr=1e-5)
    criterion = WGANGinkaLoss()
    # Tile images used to render generated maps as pictures.
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)
    if args.resume:
        data = torch.load(args.state_ginka, map_location=device)
        ginka.load_state_dict(data["model_state"], strict=False)
        data = torch.load(args.state_minamo, map_location=device)
        minamo.load_state_dict(data["model_state"], strict=False)
        print("Train from loaded state.")
    for epoch in tqdm(range(args.epochs), desc="GAN Training"):
        loss_total_minamo = torch.Tensor([0]).to(device)
        loss_total_ginka = torch.Tensor([0]).to(device)
        dis_total = torch.Tensor([0]).to(device)
        for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress"):
            batch_size = real_data.size(0)
            real_data = real_data.to(device)
            real_graph = batch_convert_soft_map_to_graph(real_data)
            # NOTE(review): zero_grad for the generator is called once per
            # batch, but the generator steps below run g_steps times without
            # re-zeroing, so their gradients accumulate across steps —
            # confirm whether that accumulation is intentional.
            optimizer_ginka.zero_grad()
            # ---------- Train the critic
            for _ in range(c_steps):
                # Generate fake samples (detached: no generator gradients).
                optimizer_minamo.zero_grad()
                z = torch.randn(batch_size, 1024, device=device)
                fake_data = ginka(z)
                fake_data = fake_data.detach()
                # Critic forward pass
                # and backpropagation.
                dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
                loss_d.backward()
                # torch.nn.utils.clip_grad_norm_(minamo_vis.parameters(), max_norm=1.0)
                optimizer_minamo.step()
                loss_total_minamo += loss_d
                dis_total += dis
            # ---------- Train the generator
            for _ in range(g_steps):
                z1 = torch.randn(batch_size, 1024, device=device)
                z2 = torch.randn(batch_size, 1024, device=device)
                fake_softmax1, fakse_softmax2 = ginka(z1), ginka(z2)
                loss_g = criterion.generator_loss(minamo, fake_softmax1, fakse_softmax2)
                loss_g.backward()
                optimizer_ginka.step()
                loss_total_ginka += loss_g
            # tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")
        avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
        avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
        avg_dis = dis_total.item() / len(dataloader) / c_steps
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}"
        )
        # Adapt generator steps: the more negative the Wasserstein estimate
        # (critic winning), the more generator updates next epoch.
        if avg_dis < -9:
            g_steps = 7
        elif avg_dis < -6:
            g_steps = 5
        elif avg_dis < -3:
            g_steps = 3
        else:
            g_steps = 1
        # Every five epochs, render sample images and save a checkpoint.
        if (epoch + 1) % 5 == 0:
            # Render 20 images: five batches of four.
            idx = 0
            with torch.no_grad():
                for _ in range(5):
                    z = torch.randn(4, 1024, device=device)
                    output = ginka(z)
                    map_matrix = torch.argmax(output, dim=1).cpu().numpy()
                    for matrix in map_matrix:
                        image = matrix_to_image_cv(matrix, tile_dict)
                        cv2.imwrite(f"result/ginka_img/{idx}.png", image)
                        idx += 1
            # Save checkpoints.
            torch.save({
                "model_state": ginka.state_dict()
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")
    print("Train ended.")
    torch.save({
        "model_state": ginka.state_dict()
    }, f"result/ginka.pth")
    torch.save({
        "model_state": minamo.state_dict()
    }, f"result/minamo.pth")
if __name__ == "__main__":
    # Cap intra-op CPU threads before launching training.
    torch.set_num_threads(4)
    train()

View File

@ -1,7 +1,9 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .vision import MinamoVisionModel
from .topo import MinamoTopoModel
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT
class MinamoModel(nn.Module):
def __init__(self, tile_types=32):
@ -16,3 +18,62 @@ class MinamoModel(nn.Module):
topo_feat = self.topo_model(graph)
return vision_feat, topo_feat
class MinamoVisionScore(nn.Module):
    """Scalar vision critic head built on MinamoVisionModel."""

    def __init__(self, tile_types=32):
        super().__init__()
        # Vision feature extractor.
        self.vision_model = MinamoVisionModel(tile_types)
        # Output head: 512-d vision feature -> scalar score.
        self.vision_fc = nn.Sequential(
            nn.Linear(512, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, 1)
        )

    def forward(self, map):
        vision_feat = self.vision_model(map)
        vis_score = self.vision_fc(vision_feat)
        return vis_score
class MinamoTopoScore(nn.Module):
    """Scalar topology critic head built on MinamoTopoModel."""

    def __init__(self, tile_types=32):
        super().__init__()
        # Topology feature extractor.
        self.topo_model = MinamoTopoModel(tile_types)
        # Output head: 512-d topology feature -> scalar score.
        self.topo_fc = nn.Sequential(
            nn.Linear(512, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, 1)
        )

    def forward(self, graph):
        topo_feat = self.topo_model(graph)
        topo_score = self.topo_fc(topo_feat)
        return topo_score
class MinamoScoreModule(nn.Module):
    """Combined WGAN critic: weighted sum of vision and topology scores.

    forward returns (combined score, vision score, topo score); the mix uses
    the module-level VISION_WEIGHT / TOPO_WEIGHT constants.
    """

    def __init__(self, tile_types=32):
        super().__init__()
        self.topo_model = MinamoTopoModel(tile_types)
        self.vision_model = MinamoVisionModel(tile_types)
        # One scalar output head per modality.
        self.topo_fc = nn.Sequential(
            nn.Linear(512, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, 1)
        )
        self.vision_fc = nn.Sequential(
            nn.Linear(512, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, 1)
        )

    def forward(self, map, graph):
        topo_feat = self.topo_model(graph)
        topo_score = self.topo_fc(topo_feat)
        vision_feat = self.vision_model(map)
        vision_score = self.vision_fc(vision_feat)
        score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
        return score, vision_score, topo_score

View File

@ -1,56 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, global_max_pool
from torch.nn.utils import spectral_norm
from torch_geometric.nn import GATConv, global_max_pool, GCNConv, global_mean_pool
from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):
def __init__(
self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512, mlp_dim=512
self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512, feat_dim=512
):
super().__init__()
# 传入 softmax 概率值,直接映射
self.input_proj = nn.Linear(tile_types, emb_dim)
self.input_proj = nn.Sequential(
nn.Linear(tile_types, emb_dim),
nn.LeakyReLU(0.2)
)
# 图卷积层
self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
self.conv2 = GATConv(hidden_dim*16, hidden_dim*2, heads=8)
self.conv3 = GATConv(hidden_dim*16, hidden_dim*2, heads=8)
self.conv4 = GATConv(hidden_dim*16, out_dim, heads=1)
self.conv1 = GATConv(emb_dim, hidden_dim, heads=8)
self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8)
self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1)
# 正则化
self.norm1 = nn.LayerNorm(hidden_dim*16)
self.norm2 = nn.LayerNorm(hidden_dim*16)
self.norm3 = nn.LayerNorm(hidden_dim*16)
self.norm4 = nn.LayerNorm(out_dim)
# self.norm1 = nn.LayerNorm(hidden_dim*8)
# self.norm2 = nn.LayerNorm(hidden_dim*8)
# self.norm3 = nn.LayerNorm(out_dim)
self.drop = nn.Dropout(0.3)
# 增强MLP
self.fc = nn.Sequential(
nn.Linear(out_dim, mlp_dim),
nn.Linear(out_dim, feat_dim),
nn.LeakyReLU(0.2)
)
def forward(self, graph: Data):
x = self.input_proj(graph.x)
x = self.conv1(x, graph.edge_index)
x = F.relu(self.norm1(x))
x = F.leaky_relu(x, 0.2)
x = self.conv2(x, graph.edge_index)
x = F.relu(self.norm2(x))
x = F.leaky_relu(x, 0.2)
x = self.conv3(x, graph.edge_index)
x = F.relu(self.norm3(x))
x = self.conv4(x, graph.edge_index)
x = F.relu(self.norm4(x))
x = F.leaky_relu(x, 0.2)
# 池化
x = self.drop(x)
x = global_max_pool(x, graph.batch)
x = global_mean_pool(x, graph.batch)
topo_vec = self.fc(x)
# 归一化
return F.normalize(topo_vec, p=2, dim=-1)
return topo_vec

View File

@ -1,14 +1,28 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18
from torch.nn.utils import spectral_norm
class MinamoVisionModel(nn.Module):
def __init__(self, tile_types=32, out_dim=512):
def __init__(self, in_ch=32, out_dim=512):
super().__init__()
self.resnet = resnet18(num_classes=out_dim)
self.resnet.conv1 = nn.Conv2d(tile_types, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.conv = nn.Sequential(
spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3, stride=2)), # 6*6
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)), #4*4
nn.LeakyReLU(0.2),
spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)), # 2*2
nn.LeakyReLU(0.2),
nn.Flatten()
)
self.fc = nn.Sequential(
spectral_norm(nn.Linear(in_ch*8*2*2, out_dim))
)
def forward(self, x):
vision_vec = self.resnet(x)
return F.normalize(vision_vec, p=2, dim=-1) # 归一化
x = self.conv(x)
x = self.fc(x)
return x

6
shared/constant.py Normal file
View File

@ -0,0 +1,6 @@
# Feature dimensions shared between the Ginka generator and the Minamo critic.
VIS_DIM = 512
TOPO_DIM = 512
FEAT_DIM = 1024
# Mixing weights for the combined critic score: vision is currently disabled,
# so the score is driven entirely by the topology branch.
VISION_WEIGHT = 0
TOPO_WEIGHT = 1

View File

@ -1,5 +1,6 @@
import torch
from torch_geometric.data import Data, Batch
from torch_geometric.utils import add_self_loops
def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
"""
@ -44,6 +45,8 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
src_feat = map_probs[:, edge_src // W, edge_src % W].T # [E, C]
dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T # [E, C]
edge_attr = (src_feat + dst_feat) / 2 * edge_mask # [E, C]
edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
return Data(
x=node_features,