Mirror of https://github.com/unanmed/ginka-generator.git, synced 2026-05-14 04:41:12 +08:00

refactor: switch to Wasserstein GAN

This commit is contained in:
parent d7209a68a2
commit 29cfb4d029
@@ -4,7 +4,7 @@ import { MinamoTrainData } from './types';
 import { generateTrainData } from './process/minamo';

 const SOCKET_FILE = '../tmp/ginka_uds';

-const [refer] = process.argv.slice(2);
+const [refer, replayPath = '../datasets/replay.bin'] = process.argv.slice(2);

 let id = 0;

@@ -43,7 +43,7 @@ function generateGANData(
         const size2: [number, number] = [map[0].length, map.length];
         if (size1[0] !== size2[0] || size1[1] !== size2[1]) return [];

-        return generateTrainData(v, id2, floor.map, map, size1);
+        return generateTrainData(v, id2, floor.map, map, size1, false, false, false);
     });
     return data.flat();
 }
@@ -104,7 +104,7 @@ class DataReceiver {
         console.log(`UDS IPC connected successfully.`);
     });

-    client.on('data', buffer => {
+    client.on('data', async buffer => {
         const data = DataReceiver.check(buffer);
         if (!data) return;

@@ -112,20 +112,19 @@ class DataReceiver {
         const simData = map.map(v => generateGANData(keys, referTower, v));
         const rc = 0;
         const compareData = simData.flat();
-        const reviewData: MinamoTrainData[] = [];

         // Node output protocol for data transfer, in bytes:
-        // 2 - Tensor count; 2 - Review count. Review is right behind train data;
+        // 2 - Tensor count; 2 - Replay count. Replay is right behind train data;
         // 1*tc - Compare count for every map tensor delivered.
         // 2*4*(N+rc) - Vision similarity and topo similarity, like vis, topo, vis, topo;
-        // N*1*H*W - Compare map for every map tensor. rc*2*H*W - Review map tensor.
+        // N*1*H*W - Compare map for every map tensor. rc*2*H*W - Replay map tensor.
         const toSend = Buffer.alloc(
             2 + // Tensor count
-            2 + // Review count
+            2 + // Replay count
             1 * count + // Compare count
             2 * 4 * (compareData.length + rc) + // Similarity data
             compareData.length * 1 * h * w + // Compare map
-            rc * 2 * h * w, // Review map
+            rc * 2 * h * w, // Replay map
             0
         );
         console.log(
@@ -141,7 +140,7 @@ class DataReceiver {

         let offset = 0;
         toSend.writeInt16BE(count); // Tensor count
-        toSend.writeInt16BE(0, 2); // Review count
+        toSend.writeInt16BE(0, 2); // Replay count
         offset += 2 + 2;
         // Compare count
         toSend.set(
@@ -164,13 +163,6 @@ class DataReceiver {
             offset // Set from Compare map
         );
         offset += compareData.length * 1 * h * w;
-        if (reviewData.length > 0) {
-            // Review map
-            toSend.set(
-                new Uint8Array(reviewData.map(v => [v.map1, v.map2]).flat(4)),
-                offset // Set from last chunk
-            );
-        }

         client.write(toSend);
     });
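For orientation, a minimal Python sketch of how the trainer side might parse a buffer with the layout documented above. This is a hypothetical helper, not code from the repo; it assumes big-endian encoding throughout (matching writeInt16BE) and float32 similarity values:

import struct
import numpy as np

def parse_node_output(buf: bytes, h: int = 13, w: int = 13):
    """Hypothetical parser for the byte protocol documented above."""
    tc, rc = struct.unpack_from('>hh', buf, 0)        # tensor count, replay count
    offset = 4
    counts = np.frombuffer(buf, np.int8, tc, offset)  # compare count per tensor
    offset += tc
    n = int(counts.sum())                             # total compare pairs (N)
    sims = np.frombuffer(buf, '>f4', 2 * (n + rc), offset).reshape(-1, 2)
    offset += 2 * 4 * (n + rc)                        # (vis, topo) pairs
    compare = np.frombuffer(buf, np.int8, n * h * w, offset).reshape(n, h, w)
    offset += n * h * w
    replay = np.frombuffer(buf, np.int8, rc * 2 * h * w, offset).reshape(rc, 2, h, w)
    return sims, compare, replay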
@@ -192,7 +192,10 @@ export function generateTrainData(
     id2: string,
     map1: number[][],
     map2: number[][],
-    size: [number, number]
+    size: [number, number],
+    hasSelf: boolean = true,
+    hasTransform: boolean = true,
+    hasSimilar: boolean = true
 ) {
     const topoSimilarity = compareMap(id1, id2, map1, map2);
     const visionSimilarity = calculateVisualSimilarity(map1, map2);
@@ -205,39 +208,47 @@ export function generateTrainData(
     };
     const data: MinamoTrainData[] = [];
     data.push(train);
-    // Self-to-self comparison samples, ensuring the model outputs 1 for identical maps
-    const self1 = `${id1}:${id1}`;
-    const self2 = `${id2}:${id2}`;
-    const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
-    if (selfTrain.includes(self1)) {
-        const selfTrain1: MinamoTrainData = {
-            map1: map1,
-            map2: map1,
-            topoSimilarity: 1,
-            visionSimilarity: 1,
-            size: size
-        };
-        data.push(selfTrain1);
-    }
-    if (selfTrain.includes(self2)) {
-        const selfTrain2: MinamoTrainData = {
-            map1: map2,
-            map2: map2,
-            topoSimilarity: 1,
-            visionSimilarity: 1,
-            size: size
-        };
-        data.push(selfTrain2);
-    }
-    const transform = generateTransformData(
-        id1,
-        id2,
-        map1,
-        map2,
-        topoSimilarity
-    );
-    const similar = generateSimilarData(id1, map1);
-    return [...data, ...transform.map(v => v[1]), ...similar.map(v => v[1])];
+    if (hasSelf) {
+        // Self-to-self comparison samples, ensuring the model outputs 1 for identical maps
+        const self1 = `${id1}:${id1}`;
+        const self2 = `${id2}:${id2}`;
+        const selfTrain = chooseFrom([self1, self2], Math.floor(Math.random() * 1));
+        if (selfTrain.includes(self1)) {
+            const selfTrain1: MinamoTrainData = {
+                map1: map1,
+                map2: map1,
+                topoSimilarity: 1,
+                visionSimilarity: 1,
+                size: size
+            };
+            data.push(selfTrain1);
+        }
+        if (selfTrain.includes(self2)) {
+            const selfTrain2: MinamoTrainData = {
+                map1: map2,
+                map2: map2,
+                topoSimilarity: 1,
+                visionSimilarity: 1,
+                size: size
+            };
+            data.push(selfTrain2);
+        }
+    }
+    if (hasTransform) {
+        const transform = generateTransformData(
+            id1,
+            id2,
+            map1,
+            map2,
+            topoSimilarity
+        );
+        data.push(...transform.map(v => v[1]))
+    }
+    if (hasSimilar) {
+        const similar = generateSimilarData(id1, map1);
+        data.push(...similar.map(v => v[1]))
+    }
+    return data;
 }

 export function generatePair(
@@ -87,8 +87,7 @@ function weisfeilerLehmanIteration(
         });
         weight *= decay;
     });
-    // Also add each node's original label; the weight is the farthest weight decayed once more, which can be read as the resource repetition rate
-    weight *= decay;
+    // Also add each node's original label; the weight is the farthest weight, which can be read as the resource repetition rate
    nodes.forEach(node => {
        if (!numMap.has(node.originalLabel)) {
            numMap.set(node.originalLabel, weight);
@@ -51,6 +51,25 @@ class GinkaDataset(Dataset):
             "target": target,
         }

+class GinkaWGANDataset(Dataset):
+    def __init__(self, data_path: str, device):
+        self.data = load_data(data_path)  # custom data-loading function
+        self.device = device
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        item = self.data[idx]
+
+        target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float()  # [32, H, W]
+        min_main = random.uniform(0.75, 0.9)
+        max_main = random.uniform(0.9, 1)
+        epsilon = random.uniform(0, 0.25)
+        target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon).to(self.device)
+
+        return target_smooth
+
 class MinamoGANDataset(Dataset):
     def __init__(self, refer_data_path):
         self.refer = load_minamo_gan_data(load_data(refer_data_path))
@@ -67,6 +86,7 @@ class MinamoGANDataset(Dataset):
     def __len__(self):
         return len(self.data)

     def __getitem__(self, idx):
+        # assume map2 is the reference map
         item = self.data[idx]

         map1, map2, vis_sim, topo_sim, review = item
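random_smooth_onehot itself is not shown in this diff. A plausible sketch of such a label-smoothing helper, under the assumption that it keeps the true class at a random weight in [min_main, max_main], adds epsilon-scaled noise over all classes, and renormalizes (the behavior is an assumption, not the repo's actual implementation):

import torch

def random_smooth_onehot(onehot: torch.Tensor, min_main: float,
                         max_main: float, epsilon: float) -> torch.Tensor:
    # Hypothetical re-implementation. onehot: [C, H, W] hard one-hot map;
    # returns a soft per-pixel class distribution.
    main = min_main + (max_main - min_main) * torch.rand(1, *onehot.shape[1:])
    smooth = onehot * main + torch.rand_like(onehot) * epsilon
    return smooth / smooth.sum(dim=0, keepdim=True)  # renormalize per pixel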
@@ -3,8 +3,10 @@ from tqdm import tqdm
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch_geometric.data import Data
 from minamo.model.model import MinamoModel
+from shared.graph import batch_convert_soft_map_to_graph
+from shared.constant import VISION_WEIGHT, TOPO_WEIGHT

 CLASS_NUM = 32
 ILLEGAL_MAX_NUM = 12
@@ -260,10 +262,129 @@ class GinkaLoss(nn.Module):
         )

         losses = [
-            minamo_loss * self.weight[0],
+            minamo_loss * self.weight[0] * 4,
             class_loss * self.weight[1],
             entrance_loss * self.weight[2],
             count_loss * self.weight[3]
         ]

         return sum(losses)

+# Interpolate image data
+def interpolate_data(real_data, fake_data, epsilon):
+    return epsilon * real_data + (1 - epsilon) * fake_data
+
+# Interpolate node features while keeping the edge connectivity unchanged
+def interpolate_graph_features(real_graph, fake_graph, epsilon=0.5):
+    # Interpolate node features
+    x_real, x_fake = real_graph.x, fake_graph.x
+    x_interp = epsilon * x_real + (1 - epsilon) * x_fake
+
+    # Keep edge connectivity and edge features unchanged
+    edge_index_interp = real_graph.edge_index  # keep the edge connectivity
+    edge_attr_interp = real_graph.edge_attr  # keep edge features, if any, unchanged
+
+    return Data(x=x_interp, edge_index=edge_index_interp, edge_attr=edge_attr_interp)
+
+def js_divergence(P, Q, epsilon=1e-10):
+    """
+    Input:
+        P, Q: [B, C, H, W], already softmax-normalized
+    Output:
+        scalar JS divergence (global mean)
+    """
+    # Move channels last ([B, H, W, C]) so the distribution lies on the final dim
+    P = P.permute(0, 2, 3, 1)  # [B, H, W, C]
+    Q = Q.permute(0, 2, 3, 1)
+
+    # Mixture distribution M = (P + Q) / 2
+    M = 0.5 * (P + Q)
+
+    # Compute KL(P||M) and KL(Q||M)
+    kl_pm = F.kl_div(torch.log(M + epsilon), P, reduction='none', log_target=False).sum(dim=-1)  # [B, H, W]
+    kl_qm = F.kl_div(torch.log(M + epsilon), Q, reduction='none', log_target=False).sum(dim=-1)  # [B, H, W]
+
+    # JS divergence = 0.5 * (KL(P||M) + KL(Q||M))
+    js = 0.5 * (kl_pm + kl_qm)
+
+    # Global mean (other aggregations could be substituted)
+    return js.mean()  # scalar
+
+class WGANGinkaLoss:
+    def __init__(self, lambda_gp=10, weight=[0.7, 0.2, 0.1], diversity_lamda=0):
+        self.lambda_gp = lambda_gp  # gradient penalty coefficient
+        self.weight = weight
+        self.diversity_lamda = diversity_lamda
+
+    def compute_gradient_penalty(self, critic, real_data, fake_data):
+        # Build interpolated samples
+        batch_size = real_data.size(0)
+        epsilon_data = torch.randn(batch_size, 1, 1, 1, device=real_data.device)
+        interp_data = interpolate_data(real_data, fake_data, epsilon_data)
+        interp_graph = batch_convert_soft_map_to_graph(interp_data)
+
+        # Enable gradients on the interpolated inputs for backpropagation
+        interp_data.requires_grad_()
+        interp_graph.x.requires_grad_()
+
+        _, d_vis_score, d_topo_score = critic(interp_data, interp_graph)
+
+        # Compute the gradients
+        grad_vis = torch.autograd.grad(
+            outputs=d_vis_score, inputs=interp_data,
+            grad_outputs=torch.ones_like(d_vis_score),
+            create_graph=True, retain_graph=True, only_inputs=True
+        )[0]
+        grad_topo = torch.autograd.grad(
+            outputs=d_topo_score, inputs=interp_graph.x,
+            grad_outputs=torch.ones_like(d_topo_score),
+            create_graph=True, retain_graph=True, only_inputs=True
+        )[0]
+
+        # L2 norm of the gradients
+        grad_norm_vis = grad_vis.view(batch_size, -1).norm(2, dim=1)
+        grad_norm_topo = grad_topo.view(batch_size, -1).norm(2, dim=1)
+        # Gradient penalty terms
+        gp_loss_vis = ((grad_norm_vis - 1.0) ** 2).mean()
+        gp_loss_topo = ((grad_norm_topo - 1.0) ** 2).mean()
+        gp_loss = gp_loss_vis * VISION_WEIGHT + gp_loss_topo * TOPO_WEIGHT
+        # print(grad_norm_topo.mean().item(), grad_norm_vis.mean().item())
+
+        return gp_loss
+
+    def discriminator_loss(
+        self, critic, real_data: torch.Tensor,
+        real_graph: torch.Tensor, fake_data: torch.Tensor
+    ):
+        """ Critic (discriminator) loss """
+        fake_graph = batch_convert_soft_map_to_graph(fake_data)
+        real_scores, _, _ = critic(real_data, real_graph)
+        fake_scores, _, _ = critic(fake_data, fake_graph)
+
+        # Wasserstein distance
+        d_loss = fake_scores.mean() - real_scores.mean()
+        grad_loss = self.compute_gradient_penalty(critic, real_data, fake_data)
+
+        return d_loss, d_loss + self.lambda_gp * grad_loss
+
+    def generator_loss_one(self, critic, fake):
+        fake_graph = batch_convert_soft_map_to_graph(fake)
+        fake_scores, _, _ = critic(fake, fake_graph)
+        minamo_loss = -torch.mean(fake_scores)
+        class_loss = outer_border_constraint_loss(fake) + inner_constraint_loss(fake)
+        entrance_loss = entrance_constraint_loss(fake)
+
+        losses = [
+            minamo_loss * self.weight[0],
+            class_loss * self.weight[1],
+            entrance_loss * self.weight[2]
+        ]
+
+        return sum(losses)
+
+    def generator_loss(self, critic, fake1, fake2):
+        """ Generator loss """
+        loss1 = self.generator_loss_one(critic, fake1)
+        loss2 = self.generator_loss_one(critic, fake2)
+
+        return loss1 * 0.5 + loss2 * 0.5 - self.diversity_lamda * js_divergence(fake1, fake2)
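For reference, the objectives these methods implement. js_divergence computes the symmetric divergence against the equal mixture M, and the critic/generator pair follows the standard WGAN-GP formulation (D is the critic, G the generator, lambda the lambda_gp coefficient; the generator additionally adds the weighted constraint losses and subtracts the diversity term):

\mathrm{JS}(P \parallel Q) = \tfrac{1}{2}\,\mathrm{KL}(P \parallel M) + \tfrac{1}{2}\,\mathrm{KL}(Q \parallel M), \qquad M = \tfrac{1}{2}(P + Q)

L_D = \mathbb{E}_{\tilde x \sim P_g}[D(\tilde x)] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda\,\mathbb{E}_{\hat x}\big[(\lVert \nabla_{\hat x} D(\hat x)\rVert_2 - 1)^2\big]

L_G = -\,\mathbb{E}_{\tilde x \sim P_g}[D(\tilde x)] - \lambda_{\mathrm{div}}\,\mathrm{JS}(G(z_1), G(z_2)) + \text{constraint terms}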
@@ -3,7 +3,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from .unet import GinkaUNet
 from .output import GinkaOutput
-from .input import GinkaInput, GinkaFeatureInput
+from .input import GinkaInput

 def print_memory(tag=""):
     print(f"{tag} | current VRAM: {torch.cuda.memory_allocated() / 1024**2:.2f} MB, peak VRAM: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
@@ -13,9 +13,7 @@ class GinkaModel(nn.Module):
         """Ginka Model definition
         """
         super().__init__()
-        self.input = GinkaInput(feat_dim, 1, (32, 32))
-        self.feat_enc = GinkaFeatureInput(feat_dim, 2, base_ch)
-        self.unet = GinkaUNet(1, base_ch, base_ch, feat_dim)
+        self.unet = GinkaUNet(base_ch, base_ch, feat_dim)
         self.output = GinkaOutput(base_ch, out_ch, (13, 13))

     def forward(self, x):
@@ -25,12 +23,9 @@ class GinkaModel(nn.Module):
         Returns:
             logits: output logits [BS, num_classes, H, W]
         """
-        cond = x
-        feat = self.feat_enc(x)
-        x = self.input(x)
-        x = self.unet(x, feat, cond)
+        x = self.unet(x)
         x = self.output(x)
-        return x, F.softmax(x, dim=1)
+        return F.softmax(x, dim=1)

 # Check GPU memory usage
 if __name__ == "__main__":
@@ -48,8 +43,8 @@ if __name__ == "__main__":

     print(f"Input shape: feat={feat.shape}")
     print(f"Output shape: output={output.shape}, softmax={output_softmax.shape}")
-    print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
-    print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
+    # print(f"Input parameters: {sum(p.numel() for p in model.input.parameters())}")
+    # print(f"Feature Encoder parameters: {sum(p.numel() for p in model.feat_enc.parameters())}")
     print(f"UNet parameters: {sum(p.numel() for p in model.unet.parameters())}")
     print(f"Output parameters: {sum(p.numel() for p in model.output.parameters())}")
     print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")
@@ -1,40 +1,51 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from shared.constant import FEAT_DIM

+class GinkaTransformerEncoder(nn.Module):
+    def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
+        super().__init__()
+        in_dim = in_dim // token_size
+        hidden_dim = hidden_dim // token_size
+        out_dim = out_dim // token_size
+        self.embedding = nn.Sequential(
+            nn.Linear(in_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim)
+        )
+        self.pos_embedding = nn.Parameter(torch.randn(1, token_size, hidden_dim))
+        self.transformer = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=ff_dim, batch_first=True),
+            num_layers=num_layers
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(hidden_dim, out_dim),
+            nn.LayerNorm(out_dim)
+        )
+
+    def forward(self, x):
+        # input  [B, L, in_dim]
+        # output [B, L, out_dim]
+        x = self.embedding(x)  # [B, L, hidden_dim]
+        x = x + self.pos_embedding  # [B, L, hidden_dim]
+        x = self.transformer(x)  # [B, L, hidden_dim]
+        x = self.fc(x)  # [B, L, out_dim]
+        return x
+
 class ConvBlock(nn.Module):
     def __init__(self, in_ch, out_ch):
         super().__init__()
         self.conv = nn.Sequential(
-            nn.Conv2d(in_ch, out_ch, 3, padding=1),
-            nn.BatchNorm2d(out_ch),
-            nn.ReLU(),
-            nn.Conv2d(out_ch, out_ch, 3, padding=1),
-            nn.BatchNorm2d(out_ch),
-            nn.ReLU(),
+            nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
+            nn.InstanceNorm2d(out_ch),
+            nn.ELU(),
+            nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
+            nn.InstanceNorm2d(out_ch),
+            nn.ELU(),
         )

     def forward(self, x):
         return self.conv(x)

 class ConditionFusionBlock(nn.Module):
     def __init__(self):
         super().__init__()
         self.alpha = nn.Parameter(torch.tensor(0.5))  # learnable fusion coefficient

     def forward(self, x, cond_feat):
         return x + self.alpha * cond_feat  # residual fusion

 class FusionConvBlock(nn.Module):
     def __init__(self, in_ch, out_ch):
         super().__init__()
         self.conv = ConvBlock(in_ch, out_ch)
         self.fusion = ConditionFusionBlock()

     def forward(self, x, feat):
         x = self.conv(x)
         x = self.fusion(x, feat)
         return x

 class GinkaEncoder(nn.Module):
     """Encoder (downsampling) part"""
@@ -42,12 +53,10 @@ class GinkaEncoder(nn.Module):
         super().__init__()
         self.conv = ConvBlock(in_ch, out_ch)
         self.pool = nn.MaxPool2d(2)
-        self.fusion = ConditionFusionBlock()

-    def forward(self, x, feat):
+    def forward(self, x):
         x = self.conv(x)
         x = self.pool(x)
-        x = self.fusion(x, feat)
         return x

 class GinkaUpSample(nn.Module):
@@ -55,8 +64,8 @@ class GinkaUpSample(nn.Module):
         super().__init__()
         self.conv = nn.Sequential(
             nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
-            nn.BatchNorm2d(out_ch),
-            nn.GELU(),
+            nn.InstanceNorm2d(out_ch),
+            nn.ELU(),
         )

     def forward(self, x):
@@ -68,46 +77,57 @@ class GinkaDecoder(nn.Module):
         super().__init__()
         self.upsample = GinkaUpSample(in_ch, in_ch // 2)
         self.conv = ConvBlock(in_ch, out_ch)
-        self.fusion = ConditionFusionBlock()

-    def forward(self, x, skip, feat):
-        dec = self.upsample(x)
-        x = torch.cat([dec, skip], dim=1)
+    def forward(self, x, feat):
+        x = self.upsample(x)
+        x = torch.cat([x, feat], dim=1)
         x = self.conv(x)
-        x = self.fusion(x, feat)
         return x

 class GinkaUNet(nn.Module):
-    def __init__(self, in_ch=1, base_ch=64, out_ch=32, feat_dim=1024):
+    def __init__(self, base_ch=64, out_ch=32, feat_dim=1024):
         """Ginka Model UNet part
         """
         super().__init__()
-        self.in_conv = FusionConvBlock(in_ch, base_ch)
-        self.down1 = GinkaEncoder(base_ch, base_ch*2)
-        self.down2 = GinkaEncoder(base_ch*2, base_ch*4)
-        self.down3 = GinkaEncoder(base_ch*4, base_ch*8)
-
-        self.bottleneck = GinkaEncoder(base_ch*8, base_ch*16)
-
-        self.up1 = GinkaDecoder(base_ch*16, base_ch*8)
-        self.up2 = GinkaDecoder(base_ch*8, base_ch*4)
-        self.up3 = GinkaDecoder(base_ch*4, base_ch*2)
-        self.up4 = GinkaDecoder(base_ch*2, base_ch)
+        self.input = GinkaTransformerEncoder(
+            in_dim=feat_dim, hidden_dim=feat_dim*2, out_dim=2*32*32,  # divided by token_size internally
+            token_size=4, ff_dim=feat_dim*2, num_layers=4
+        )
+        self.down1 = ConvBlock(2, base_ch)
+        self.down2 = GinkaEncoder(base_ch, base_ch*2)
+        self.down3 = GinkaEncoder(base_ch*2, base_ch*4)
+        self.down4 = GinkaEncoder(base_ch*4, base_ch*8)
+        self.bottleneck = GinkaTransformerEncoder(
+            in_dim=base_ch*8*4*4, hidden_dim=base_ch*8*4*4, out_dim=base_ch*8*4*4,
+            token_size=16, ff_dim=1024, num_layers=4
+        )
+
+        self.up1 = GinkaDecoder(base_ch*8, base_ch*4)
+        self.up2 = GinkaDecoder(base_ch*4, base_ch*2)
+        self.up3 = GinkaDecoder(base_ch*2, base_ch)

         self.final = nn.Sequential(
             nn.Conv2d(base_ch, out_ch, 1),
             nn.InstanceNorm2d(out_ch),
             nn.ELU(),
         )

-    def forward(self, x, feat, cond):
-        x1 = self.in_conv(x, feat[0])
-        x2 = self.down1(x1, feat[1])
-        x3 = self.down2(x2, feat[2])
-        x4 = self.down3(x3, feat[3])
-        x5 = self.bottleneck(x4, feat[4])
+    def forward(self, x):
+        B, D = x.shape  # [B, 1024]
+        x = x.view(B, 4, D // 4)  # [B, 4, 256]
+        x = self.input(x)  # [B, 4, 512]
+        x = x.view(B, 2, 32, 32)  # [B, 2, 32, 32]
+        x1 = self.down1(x)  # [B, 64, 32, 32]
+        x2 = self.down2(x1)  # [B, 128, 16, 16]
+        x3 = self.down3(x2)  # [B, 256, 8, 8]
+        x4 = self.down4(x3)  # [B, 512, 4, 4]
+        x4 = x4.view(B, 512, 16).permute(0, 2, 1)  # [B, 16, 512]
+        x4 = self.bottleneck(x4)  # [B, 16, 512]
+        x4 = x4.permute(0, 2, 1).view(B, 512, 4, 4)  # [B, 512, 4, 4]

-        x = self.up1(x5, x4, feat[3])
-        x = self.up2(x, x3, feat[2])
-        x = self.up3(x, x2, feat[1])
-        x = self.up4(x, x1, feat[0])
+        # Upsampling
+        x = self.up1(x4, x3)  # [B, 256, 8, 8]
+        x = self.up2(x, x2)  # [B, 128, 16, 16]
+        x = self.up3(x, x1)  # [B, 64, 32, 32]

-        return self.final(x)
+        return self.final(x)  # [B, 32, 32, 32]
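A quick shape check of the new noise-to-map pipeline, as a sketch. It assumes the module is importable from this repo as ginka.model.unet.GinkaUNet and simply verifies the per-stage shapes annotated above:

import torch
from ginka.model.unet import GinkaUNet  # path as laid out in this repo

unet = GinkaUNet(base_ch=64, out_ch=32, feat_dim=1024)
z = torch.randn(2, 1024)   # [B, 1024] latent vector
out = unet(z)              # tokens -> [B, 2, 32, 32] -> UNet -> [B, 32, 32, 32]
print(out.shape)           # torch.Size([2, 32, 32, 32])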
@@ -11,17 +11,19 @@ from tqdm import tqdm
 import cv2
 import numpy as np
 from .model.model import GinkaModel
-from .model.loss import GinkaLoss
+from .model.loss import GinkaLoss, WGANGinkaLoss
 from .dataset import GinkaDataset, MinamoGANDataset
 from minamo.model.model import MinamoModel
 from minamo.model.loss import MinamoLoss
 from shared.image import matrix_to_image_cv

 BATCH_SIZE = 32
-EPOCHS_GINKA = 30
-EPOCHS_MINAMO = 5
+EPOCHS_GINKA = 5
+EPOCHS_MINAMO = 2
 SOCKET_PATH = "./tmp/ginka_uds"
 LOSS_PATH = "result/gan/a-loss.txt"
+REPLAY_PATH = "datasets/replay.bin"
+VISION_ALPHA = 0

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
@@ -32,6 +34,10 @@ os.makedirs("tmp", exist_ok=True)
 with open(LOSS_PATH, 'a', encoding='utf-8') as f:
     f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")

+if not os.path.exists(REPLAY_PATH):
+    with open(REPLAY_PATH, 'wb') as f:
+        f.write(b'\x00\x00\x00\x00')
+
 def parse_arguments():
     parser = argparse.ArgumentParser(description="training codes")
     parser.add_argument("--resume", type=bool, default=False)
@@ -142,14 +148,16 @@ def train():
     minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True)

     # Set up the optimizers and schedulers
-    optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3)
+    optimizer_ginka = optim.Adam(ginka.parameters(), lr=1e-4, betas=(0.0, 0.9))
     scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
     criterion_ginka = GinkaLoss(minamo)

-    optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-4)
-    scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
+    optimizer_minamo = optim.Adam(minamo.parameters(), lr=2e-5, betas=(0.0, 0.9))
+    scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=EPOCHS_MINAMO, T_mult=2, eta_min=1e-6)
     criterion_minamo = MinamoLoss()

+    criterion = WGANGinkaLoss()
+
     # Used for rendering images
     tile_dict = dict()
     for file in os.listdir('tiles'):
@@ -168,27 +176,17 @@ def train():
         ginka.load_state_dict(data["model_state"], strict=False)
         print("Train from loaded state.")

-    else:
-        # When training from scratch, set the Minamo loss weight to 0 at first
-        criterion_ginka.weight[0] = 0.0
-
     print("Waiting for client connection...")
     conn, _ = server.accept()
     print("Client connected.")

     for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
         # -------------------- Train the generator
-        gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
-        prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
-        for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model"):
+        for epoch in tqdm(range(EPOCHS_GINKA), desc="Training Ginka Model", leave=False):
             ginka.train()
             minamo.eval()
             total_loss = 0

-            # When training from scratch, restore the Minamo loss weight at epoch 10
-            if not args.resume and epoch == 10:
-                criterion_ginka.weight[0] = 0.5
-
             for batch in tqdm(ginka_dataloader, leave=False, desc="Epoch Progress"):
                 # Move the data to the device
                 target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
@@ -233,6 +231,8 @@ def train():
             }, f"result/ginka_checkpoint/{epoch + 1}.pth")

+        # Generate Minamo training data from the training set, which is more accurate
+        gen_list: np.ndarray = np.empty((0, 13, 13), np.int8)
+        prob_list: np.ndarray = np.empty((0, 32, 13, 13), np.float32)
         with torch.no_grad():
             for batch in ginka_dataloader:
                 target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
@@ -263,7 +263,7 @@ def train():
             buf.extend(gen_bytes)  # Map tensor
             conn.sendall(buf)
             data = parse_minamo_data(conn, prob_list)
-            minamo_dataset.set_data(data)

             vis_sim = 0
             topo_sim = 0
             for _, _, vis, topo, _ in data:
@@ -276,6 +276,44 @@ def train():
         with open(LOSS_PATH, 'a', encoding='utf-8') as f:
             f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}')

+        # Experience replay
+        with open(REPLAY_PATH, 'r+b') as f:
+            # Read the file header to get the total count
+            f.seek(0)
+            count = struct.unpack('>i', f.read(4))[0]  # unpack the integer
+            if count > 0:
+                replay = np.random.choice(count, size=min(count, len(data) // 4), replace=False)
+
+                replay_data = np.empty((len(replay), 32, 13, 13))
+                for i, n in enumerate(replay):
+                    f.seek(n * 32 * 13 * 13 + 4)
+                    arr = np.frombuffer(f.read(32 * 13 * 13 * 4), dtype=np.float32).reshape(32, 13, 13)
+                    replay_data[i] = arr
+
+                map_data: np.ndarray = replay_data.argmax(axis=1)
+                buf = bytearray()
+                buf.extend(struct.pack('>h', len(replay)))  # Tensor count
+                buf.extend(struct.pack('>b', H))  # Map height
+                buf.extend(struct.pack('>b', W))  # Map width
+                buf.extend(map_data.astype(np.int8).tobytes())  # Map tensor
+                conn.sendall(buf)
+                data.extend(parse_minamo_data(conn, replay_data))
+
+            # Append the new content at the end of the file
+            to_write = np.random.choice(N, size=min(N, 100), replace=False)
+            write_data = bytearray()
+            for n in to_write:
+                write_data.extend(prob_list[n].tobytes())
+
+            f.seek(0, 2)  # seek to the end of the file
+            f.write(write_data)
+
+            f.seek(0)  # seek to the start of the file
+            f.write(struct.pack('>i', count + len(to_write)))
+            f.flush()  # make sure the data is flushed to disk
+
+        minamo_dataset.set_data(data)
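The replay file used above is a flat binary: a 4-byte big-endian record count followed by fixed-size float32 soft-map records of 32x13x13 values (32*13*13*4 bytes each). A minimal standalone reader sketch under that layout:

import struct
import numpy as np

RECORD_BYTES = 32 * 13 * 13 * 4  # one float32 soft map per record

def read_replay(path: str, n: int) -> np.ndarray:
    """Read the n-th stored soft map from the replay buffer file."""
    with open(path, 'rb') as f:
        count = struct.unpack('>i', f.read(4))[0]
        if not 0 <= n < count:
            raise IndexError(f"record {n} out of range (count={count})")
        f.seek(4 + n * RECORD_BYTES)
        return np.frombuffer(f.read(RECORD_BYTES), dtype=np.float32).reshape(32, 13, 13)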
         # -------------------- Train the discriminator
         for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
             ginka.eval()
@@ -283,21 +321,43 @@ def train():
             total_loss = 0

             for batch in tqdm(minamo_dataloader, leave=False, desc="Epoch Progress"):
-                map1, map2, vision_simi, topo_simi, graph1, graph2 = parse_minamo_batch(batch)
+                map1, map2, vis_sim, topo_sim, graph1, graph2 = parse_minamo_batch(batch)
+                batch_size = map1.shape[0]

-                if map1.shape[0] == 1:
+                if batch_size == 1:
                     continue

                 # Forward pass
                 optimizer_minamo.zero_grad()
-                vision_feat1, topo_feat1 = minamo(map1, graph1)
-                vision_feat2, topo_feat2 = minamo(map2, graph2)
+                vis_feat_real, topo_feat_real = minamo(map1, graph1)
+                vis_feat_ref, topo_feat_ref = minamo(map2, graph2)

-                vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
-                topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
+                # Generate fake data
+                with torch.no_grad():
+                    fake_feat = torch.randn((batch_size, 1024), device=device)
+                    fake_data = ginka(fake_feat)
+
+                # Create interpolated samples
+                alpha = torch.rand((batch_size, 1, 1, 1), device=device)
+                interpolates = (alpha * map2 + (1 - alpha) * fake_data).requires_grad_(True)
+
+                vis_feat_fake, topo_feat_fake = minamo(fake_data)
+                vis_feat_interp, topo_feat_interp = minamo(interpolates)
+
+                vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
+                topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)
+                vis_pred_fake = F.cosine_similarity(vis_feat_fake, vis_feat_ref, dim=1).unsqueeze(-1)
+                topo_pred_fake = F.cosine_similarity(topo_feat_fake, topo_feat_ref, dim=1).unsqueeze(-1)
+                vis_pred_interp = F.cosine_similarity(vis_feat_interp, vis_feat_ref, dim=1).unsqueeze(-1)
+                topo_pred_interp = F.cosine_similarity(topo_feat_interp, topo_feat_ref, dim=1).unsqueeze(-1)
+
+                # Compute the similarities
+                score_real = F.l1_loss(vis_pred_real, vis_sim) * VISION_ALPHA + F.l1_loss(topo_pred_real, topo_sim) * (1 - VISION_ALPHA)
+                score_fake = vis_pred_fake * VISION_ALPHA + topo_pred_fake * (1 - VISION_ALPHA)
+                score_interp = vis_pred_interp * VISION_ALPHA + topo_pred_interp * (1 - VISION_ALPHA)

                 # Compute the loss
-                loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi)
+                loss = criterion.discriminator_loss(score_real, score_fake, score_interp)

                 # Backward pass
                 loss.backward()
@@ -310,21 +370,21 @@ def train():
             scheduler_minamo.step(epoch + 1)

         # Run inference on the validation set every ten epochs
-        if (epoch + 1) % 5 == 0:
+        if epoch + 1 == EPOCHS_MINAMO:
             minamo.eval()
             val_loss = 0
             with torch.no_grad():
                 for val_batch in tqdm(minamo_dataloader_val, leave=False, desc="Validating Minamo Model"):
                     map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = parse_minamo_batch(val_batch)

-                    vision_feat1, topo_feat1 = minamo(map1_val, graph1)
-                    vision_feat2, topo_feat2 = minamo(map2_val, graph2)
+                    vis_feat_real, topo_feat_real = minamo(map1_val, graph1)
+                    vis_feat_ref, topo_feat_ref = minamo(map2_val, graph2)

-                    vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
-                    topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
+                    vis_pred_real = F.cosine_similarity(vis_feat_real, vis_feat_ref, dim=1).unsqueeze(-1)
+                    topo_pred_real = F.cosine_similarity(topo_feat_real, topo_feat_ref, dim=1).unsqueeze(-1)

                     # Compute the loss
-                    loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
+                    loss_val = criterion_minamo(vis_pred_real, topo_pred_real, vision_simi_val, topo_simi_val)
                     val_loss += loss_val.item()

                 avg_val_loss = val_loss / len(minamo_dataloader_val)
ginka/train_wgan.py (new file, +164 lines)
@@ -0,0 +1,164 @@
import argparse
import os
from datetime import datetime
import torch
import torch.optim as optim
import cv2
from torch_geometric.loader import DataLoader
from tqdm import tqdm
from .model.model import GinkaModel
from .dataset import GinkaWGANDataset
from .model.loss import WGANGinkaLoss
from minamo.model.model import MinamoScoreModule
from shared.graph import batch_convert_soft_map_to_graph
from shared.image import matrix_to_image_cv
from shared.constant import VISION_WEIGHT, TOPO_WEIGHT

BATCH_SIZE = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/wgan", exist_ok=True)

def parse_arguments():
    parser = argparse.ArgumentParser(description="training codes")
    parser.add_argument("--resume", type=bool, default=False)
    parser.add_argument("--state_ginka", type=str, default="result/ginka.pth")
    parser.add_argument("--state_minamo", type=str, default="result/minamo.pth")
    parser.add_argument("--train", type=str, default="ginka-dataset.json")
    parser.add_argument("--epochs", type=int, default=100)
    args = parser.parse_args()
    return args

def clip_weights(model, clip_value=0.01):
    for param in model.parameters():
        param.data = torch.clamp(param.data, -clip_value, clip_value)

def train():
    print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to train model.")

    c_steps = 1
    g_steps = 3

    args = parse_arguments()

    ginka = GinkaModel()
    minamo = MinamoScoreModule()
    ginka.to(device)
    minamo.to(device)

    dataset = GinkaWGANDataset(args.train, device)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

    optimizer_ginka = optim.RMSprop(ginka.parameters(), lr=2e-4)
    optimizer_minamo = optim.RMSprop(minamo.parameters(), lr=1e-5)

    criterion = WGANGinkaLoss()

    # Used for rendering images
    tile_dict = dict()
    for file in os.listdir('tiles'):
        name = os.path.splitext(file)[0]
        tile_dict[name] = cv2.imread(f"tiles/{file}", cv2.IMREAD_UNCHANGED)

    if args.resume:
        data = torch.load(args.state_ginka, map_location=device)
        ginka.load_state_dict(data["model_state"], strict=False)
        data = torch.load(args.state_minamo, map_location=device)
        minamo.load_state_dict(data["model_state"], strict=False)
        print("Train from loaded state.")

    for epoch in tqdm(range(args.epochs), desc="GAN Training"):
        loss_total_minamo = torch.Tensor([0]).to(device)
        loss_total_ginka = torch.Tensor([0]).to(device)
        dis_total = torch.Tensor([0]).to(device)

        for real_data in tqdm(dataloader, leave=False, desc="Epoch Progress"):
            batch_size = real_data.size(0)
            real_data = real_data.to(device)
            real_graph = batch_convert_soft_map_to_graph(real_data)

            optimizer_ginka.zero_grad()

            # ---------- Train the critic
            for _ in range(c_steps):
                # Generate fake samples
                optimizer_minamo.zero_grad()
                z = torch.randn(batch_size, 1024, device=device)
                fake_data = ginka(z)
                fake_data = fake_data.detach()

                # Compute the critic outputs
                # Backward pass
                dis, loss_d = criterion.discriminator_loss(minamo, real_data, real_graph, fake_data)
                loss_d.backward()
                # torch.nn.utils.clip_grad_norm_(minamo_vis.parameters(), max_norm=1.0)
                optimizer_minamo.step()

                loss_total_minamo += loss_d
                dis_total += dis

            # ---------- Train the generator
            for _ in range(g_steps):
                z1 = torch.randn(batch_size, 1024, device=device)
                z2 = torch.randn(batch_size, 1024, device=device)
                fake_softmax1, fake_softmax2 = ginka(z1), ginka(z2)

                loss_g = criterion.generator_loss(minamo, fake_softmax1, fake_softmax2)
                loss_g.backward()
                optimizer_ginka.step()

                loss_total_ginka += loss_g
                # tqdm.write(f"{dis.item():.12f}, {loss_d.item():.12f}, {loss_g.item():.12f}")

        avg_loss_ginka = loss_total_ginka.item() / len(dataloader) / g_steps
        avg_loss_minamo = loss_total_minamo.item() / len(dataloader) / c_steps
        avg_dis = dis_total.item() / len(dataloader) / c_steps
        tqdm.write(
            f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Epoch: {epoch + 1} | Wasserstein Loss: {avg_dis:.8f} | Loss Ginka: {avg_loss_ginka:.8f} | Loss Minamo: {avg_loss_minamo:.8f}"
        )

        if avg_dis < -9:
            g_steps = 7
        elif avg_dis < -6:
            g_steps = 5
        elif avg_dis < -3:
            g_steps = 3
        else:
            g_steps = 1

        # Every five epochs, render images and save checkpoints
        if (epoch + 1) % 5 == 0:
            # Render 20 images: 4 per batch, five batches in total
            idx = 0
            with torch.no_grad():
                for _ in range(5):
                    z = torch.randn(4, 1024, device=device)
                    output = ginka(z)

                    map_matrix = torch.argmax(output, dim=1).cpu().numpy()
                    for matrix in map_matrix:
                        image = matrix_to_image_cv(matrix, tile_dict)
                        cv2.imwrite(f"result/ginka_img/{idx}.png", image)
                        idx += 1

            # Save checkpoints
            torch.save({
                "model_state": ginka.state_dict()
            }, f"result/wgan/ginka-{epoch + 1}.pth")
            torch.save({
                "model_state": minamo.state_dict()
            }, f"result/wgan/minamo-{epoch + 1}.pth")

    print("Train ended.")
    torch.save({
        "model_state": ginka.state_dict()
    }, f"result/ginka.pth")
    torch.save({
        "model_state": minamo.state_dict()
    }, f"result/minamo.pth")

if __name__ == "__main__":
    torch.set_num_threads(4)
    train()
@@ -1,7 +1,9 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from .vision import MinamoVisionModel
 from .topo import MinamoTopoModel
+from shared.constant import VISION_WEIGHT, TOPO_WEIGHT

 class MinamoModel(nn.Module):
     def __init__(self, tile_types=32):
@@ -16,3 +18,62 @@ class MinamoModel(nn.Module):
         topo_feat = self.topo_model(graph)

         return vision_feat, topo_feat

+class MinamoVisionScore(nn.Module):
+    def __init__(self, tile_types=32):
+        super().__init__()
+        # Vision similarity part
+        self.vision_model = MinamoVisionModel(tile_types)
+        # Output layer
+        self.vision_fc = nn.Sequential(
+            nn.Linear(512, 2048),
+            nn.LeakyReLU(0.2),
+            nn.Linear(2048, 1)
+        )
+
+    def forward(self, map):
+        vision_feat = self.vision_model(map)
+        vis_score = self.vision_fc(vision_feat)
+        return vis_score
+
+class MinamoTopoScore(nn.Module):
+    def __init__(self, tile_types=32):
+        super().__init__()
+        # Topological similarity part
+        self.topo_model = MinamoTopoModel(tile_types)
+        # Output layer
+        self.topo_fc = nn.Sequential(
+            nn.Linear(512, 2048),
+            nn.LeakyReLU(0.2),
+            nn.Linear(2048, 1)
+        )
+
+    def forward(self, graph):
+        topo_feat = self.topo_model(graph)
+        topo_score = self.topo_fc(topo_feat)
+        return topo_score
+
+class MinamoScoreModule(nn.Module):
+    def __init__(self, tile_types=32):
+        super().__init__()
+        self.topo_model = MinamoTopoModel(tile_types)
+        self.vision_model = MinamoVisionModel(tile_types)
+        # Output layers
+        self.topo_fc = nn.Sequential(
+            nn.Linear(512, 2048),
+            nn.LeakyReLU(0.2),
+            nn.Linear(2048, 1)
+        )
+        self.vision_fc = nn.Sequential(
+            nn.Linear(512, 2048),
+            nn.LeakyReLU(0.2),
+            nn.Linear(2048, 1)
+        )
+
+    def forward(self, map, graph):
+        topo_feat = self.topo_model(graph)
+        topo_score = self.topo_fc(topo_feat)
+        vision_feat = self.vision_model(map)
+        vision_score = self.vision_fc(vision_feat)
+        score = VISION_WEIGHT * vision_score + TOPO_WEIGHT * topo_score
+        return score, vision_score, topo_score
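A sketch of the critic call pattern that WGANGinkaLoss relies on, assuming the repo's modules import as laid out here and a soft map batch of shape [B, 32, H, W]:

import torch
from minamo.model.model import MinamoScoreModule
from shared.graph import batch_convert_soft_map_to_graph

critic = MinamoScoreModule()
fake = torch.softmax(torch.randn(2, 32, 13, 13), dim=1)  # soft map batch
graph = batch_convert_soft_map_to_graph(fake)
score, vision_score, topo_score = critic(fake, graph)
# With VISION_WEIGHT = 0 and TOPO_WEIGHT = 1 (shared/constant.py below),
# score equals topo_score: the combined critic is currently topology-only.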
@@ -1,56 +1,50 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch_geometric.nn import GATConv, global_max_pool
+from torch.nn.utils import spectral_norm
+from torch_geometric.nn import GATConv, global_max_pool, GCNConv, global_mean_pool
 from torch_geometric.data import Data

 class MinamoTopoModel(nn.Module):
     def __init__(
-        self, tile_types=32, emb_dim=128, hidden_dim=128, out_dim=512, mlp_dim=512
+        self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512, feat_dim=512
     ):
         super().__init__()
         # Takes softmax probabilities and projects them directly
-        self.input_proj = nn.Linear(tile_types, emb_dim)
+        self.input_proj = nn.Sequential(
+            nn.Linear(tile_types, emb_dim),
+            nn.LeakyReLU(0.2)
+        )
         # Graph convolution layers
-        self.conv1 = GATConv(emb_dim, hidden_dim*2, heads=8, dropout=0.2)
-        self.conv2 = GATConv(hidden_dim*16, hidden_dim*2, heads=8)
-        self.conv3 = GATConv(hidden_dim*16, hidden_dim*2, heads=8)
-        self.conv4 = GATConv(hidden_dim*16, out_dim, heads=1)
+        self.conv1 = GATConv(emb_dim, hidden_dim, heads=8)
+        self.conv2 = GATConv(hidden_dim*8, hidden_dim, heads=8)
+        self.conv3 = GATConv(hidden_dim*8, out_dim, heads=1)

         # Normalization
-        self.norm1 = nn.LayerNorm(hidden_dim*16)
-        self.norm2 = nn.LayerNorm(hidden_dim*16)
-        self.norm3 = nn.LayerNorm(hidden_dim*16)
-        self.norm4 = nn.LayerNorm(out_dim)
+        # self.norm1 = nn.LayerNorm(hidden_dim*8)
+        # self.norm2 = nn.LayerNorm(hidden_dim*8)
+        # self.norm3 = nn.LayerNorm(out_dim)

         self.drop = nn.Dropout(0.3)

         # Strengthened MLP head
         self.fc = nn.Sequential(
-            nn.Linear(out_dim, mlp_dim),
+            nn.Linear(out_dim, feat_dim),
             nn.LeakyReLU(0.2)
         )

     def forward(self, graph: Data):
         x = self.input_proj(graph.x)

         x = self.conv1(x, graph.edge_index)
-        x = F.relu(self.norm1(x))
+        x = F.leaky_relu(x, 0.2)

         x = self.conv2(x, graph.edge_index)
-        x = F.relu(self.norm2(x))
+        x = F.leaky_relu(x, 0.2)

         x = self.conv3(x, graph.edge_index)
-        x = F.relu(self.norm3(x))
-
-        x = self.conv4(x, graph.edge_index)
-        x = F.relu(self.norm4(x))
+        x = F.leaky_relu(x, 0.2)

         # Pooling
         x = self.drop(x)
-        x = global_max_pool(x, graph.batch)
+        x = global_mean_pool(x, graph.batch)

         topo_vec = self.fc(x)

-        # Normalize
-        return F.normalize(topo_vec, p=2, dim=-1)
+        return topo_vec
@@ -1,14 +1,28 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torchvision.models import resnet18
+from torch.nn.utils import spectral_norm

 class MinamoVisionModel(nn.Module):
-    def __init__(self, tile_types=32, out_dim=512):
+    def __init__(self, in_ch=32, out_dim=512):
         super().__init__()
-        self.resnet = resnet18(num_classes=out_dim)
-        self.resnet.conv1 = nn.Conv2d(tile_types, 64, kernel_size=7, stride=2, padding=3, bias=False)
+        self.conv = nn.Sequential(
+            spectral_norm(nn.Conv2d(in_ch, in_ch*2, 3, stride=2)),  # 6*6
+            nn.LeakyReLU(0.2),
+
+            spectral_norm(nn.Conv2d(in_ch*2, in_ch*4, 3)),  # 4*4
+            nn.LeakyReLU(0.2),
+
+            spectral_norm(nn.Conv2d(in_ch*4, in_ch*8, 3)),  # 2*2
+            nn.LeakyReLU(0.2),
+
+            nn.Flatten()
+        )
+        self.fc = nn.Sequential(
+            spectral_norm(nn.Linear(in_ch*8*2*2, out_dim))
+        )

     def forward(self, x):
-        vision_vec = self.resnet(x)
-        return F.normalize(vision_vec, p=2, dim=-1)  # normalize
+        x = self.conv(x)
+        x = self.fc(x)
+        return x
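A note on the design change here: the BatchNorm-based ResNet-18 is replaced with spectral-normalized convolutions, which fits the WGAN setting. Spectral normalization divides each weight matrix by its largest singular value, bounding the layer's Lipschitz constant; with 1-Lipschitz activations such as LeakyReLU this keeps the critic close to 1-Lipschitz and complements the gradient penalty:

\bar{W} = \frac{W}{\sigma_{\max}(W)}, \qquad \lVert f \rVert_{\mathrm{Lip}} \le \prod_i \sigma_{\max}(\bar{W}_i) = 1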
shared/constant.py (new file, +6 lines)
@@ -0,0 +1,6 @@
VIS_DIM = 512
TOPO_DIM = 512
FEAT_DIM = 1024

VISION_WEIGHT = 0
TOPO_WEIGHT = 1
@@ -1,5 +1,6 @@
 import torch
 from torch_geometric.data import Data, Batch
+from torch_geometric.utils import add_self_loops

 def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
     """
@@ -44,6 +45,8 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
     src_feat = map_probs[:, edge_src // W, edge_src % W].T  # [E, C]
     dst_feat = map_probs[:, edge_dst // W, edge_dst % W].T  # [E, C]
     edge_attr = (src_feat + dst_feat) / 2 * edge_mask  # [E, C]

+    edge_index, edge_attr = add_self_loops(edge_index, edge_attr)
+
     return Data(
         x=node_features,