Mirror of https://github.com/unanmed/ginka-generator.git (synced 2026-05-16 22:41:14 +08:00)
fix: 浮点数被意外转成 LongTensor
This commit is contained in:
parent 325eb599c3
commit 6a1aeaa77e
@@ -35,7 +35,7 @@ function generateGANData(
|
||||
map: number[][]
|
||||
) {
|
||||
const id2 = `$${id++}`;
|
||||
const toTrain = chooseFrom(keys, 30);
|
||||
const toTrain = chooseFrom(keys, 4);
|
||||
const data = toTrain.map<MinamoTrainData[]>(v => {
|
||||
const floor = refer.get(v);
|
||||
if (!floor) return [];
|
||||
|
||||
@@ -87,7 +87,8 @@ function weisfeilerLehmanIteration(
|
||||
});
|
||||
weight *= decay;
|
||||
});
|
||||
// 把每个节点的原始标签也加上,权重使用最远权重,可以认为是资源重复率
|
||||
// 把每个节点的原始标签也加上,权重使用最远权重再衰减2次,可以认为是资源重复率
|
||||
weight *= decay ** 2;
|
||||
nodes.forEach(node => {
|
||||
if (!numMap.has(node.originalLabel)) {
|
||||
numMap.set(node.originalLabel, weight);
|
||||
|
||||
@@ -37,7 +37,10 @@ class GinkaDataset(Dataset):
|
||||
item = self.data[idx]
|
||||
|
||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
target_smooth = random_smooth_onehot(target)
|
||||
min_main = random.uniform(0.75, 0.9)
|
||||
max_main = random.uniform(0.9, 1)
|
||||
epsilon = random.uniform(0, 0.25)
|
||||
target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon)
|
||||
graph = differentiable_convert_to_data(target_smooth).to(self.device)
|
||||
target = target.to(self.device)
|
||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||
@@ -67,12 +70,12 @@ class MinamoGANDataset(Dataset):
|
||||
item = self.data[idx]
|
||||
|
||||
map1, map2, vis_sim, topo_sim, review = item
|
||||
map1 = torch.LongTensor(map1)
|
||||
map2 = torch.LongTensor(map2)
|
||||
# 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换
|
||||
if review:
|
||||
map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
map2 = F.one_hot(map2, num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
else:
|
||||
map1 = torch.FloatTensor(map1)
|
||||
map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
min_main = random.uniform(0.75, 0.9)
|
||||
max_main = random.uniform(0.9, 1)
|
||||
|
||||
@@ -249,9 +249,10 @@ class GinkaLoss(nn.Module):
|
||||
graph = batch_convert_soft_map_to_graph(pred)
|
||||
pred_vision_feat, pred_topo_feat = self.minamo(pred, graph)
|
||||
|
||||
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1)
|
||||
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1)
|
||||
minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim
|
||||
vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=1)
|
||||
topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=1)
|
||||
minamo_sim = 0 * vision_sim + 1 * topo_sim
|
||||
# tqdm.write(f"{vision_sim.mean().item():.12f}, {topo_sim.mean().item():.12f}")
|
||||
minamo_loss = (1.0 - minamo_sim).mean()
|
||||
|
||||
tqdm.write(
|
||||
|
||||
@@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv
|
||||
|
||||
BATCH_SIZE = 32
|
||||
EPOCHS_GINKA = 30
|
||||
EPOCHS_MINAMO = 15
|
||||
EPOCHS_MINAMO = 10
|
||||
SOCKET_PATH = "./tmp/ginka_uds"
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
@@ -40,9 +40,9 @@ def parse_arguments():
|
||||
|
||||
def parse_ginka_batch(batch):
|
||||
target = batch["target"].to(device)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
target_vision_feat = batch["target_vision_feat"].to(device).squeeze(1)
|
||||
target_topo_feat = batch["target_topo_feat"].to(device).squeeze(1)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=1).to(device)
|
||||
|
||||
return target, target_vision_feat, target_topo_feat, feat_vec
|
||||
|
||||
@@ -133,8 +133,8 @@ def train():
|
||||
minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json")
|
||||
ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||
minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE, shuffle=True)
|
||||
minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE // 2, shuffle=True)
|
||||
minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3)
|
||||
@@ -142,7 +142,7 @@ def train():
|
||||
criterion_ginka = GinkaLoss(minamo)
|
||||
|
||||
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion_minamo = MinamoLoss()
|
||||
|
||||
# 用于生成图片
|
||||
@@ -158,20 +158,18 @@ def train():
|
||||
server.bind(SOCKET_PATH)
|
||||
server.listen(1)
|
||||
|
||||
print("Waiting for client connection...")
|
||||
conn, _ = server.accept()
|
||||
print("Client connected.")
|
||||
|
||||
if args.resume:
|
||||
data = torch.load(args.from_state, map_location=device)
|
||||
ginka.load_state_dict(data["model_state"], strict=False)
|
||||
if args.load_optim:
|
||||
optimizer_ginka.load_state_dict(data["optimizer_state"])
|
||||
print("Train from loaded state.")
|
||||
|
||||
else:
|
||||
# 从头开始训练的话,初始时先把 minamo 损失值权重改为 0
|
||||
criterion_ginka.weight[0] = 0.0
|
||||
|
||||
print("Waiting for client connection...")
|
||||
conn, _ = server.accept()
|
||||
print("Client connected.")
|
||||
|
||||
for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"):
|
||||
# -------------------- 训练生成器
|
||||
@@ -217,10 +215,7 @@ def train():
|
||||
loss_val += losses.item()
|
||||
if epoch + 1 == EPOCHS_GINKA:
|
||||
# 最后一次验证的时候顺带生成图片
|
||||
prob = output_softmax.cpu().numpy()
|
||||
prob_list = np.concatenate((prob_list, prob), axis=0)
|
||||
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
||||
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
|
||||
for matrix in map_matrix:
|
||||
image = matrix_to_image_cv(matrix, tile_dict)
|
||||
cv2.imwrite(f"result/ginka_img/{idx}.png", image)
|
||||
@@ -231,6 +226,16 @@ def train():
|
||||
torch.save({
|
||||
"model_state": ginka.state_dict()
|
||||
}, f"result/ginka_checkpoint/{epoch + 1}.pth")
|
||||
|
||||
# 使用训练集生成 minamo 训练数据,更准确
|
||||
with torch.no_grad():
|
||||
for batch in ginka_dataloader:
|
||||
target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch)
|
||||
output, output_softmax = ginka(feat_vec)
|
||||
prob = output_softmax.cpu().numpy()
|
||||
prob_list = np.concatenate((prob_list, prob), axis=0)
|
||||
map_matrix = torch.argmax(output, dim=1).cpu().numpy()
|
||||
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
|
||||
|
||||
tqdm.write(f"Cycle {cycle} Ginka train ended.")
|
||||
torch.save({
|
||||
@@ -269,8 +274,8 @@ def train():
|
||||
vision_feat1, topo_feat1 = minamo(map1, graph1)
|
||||
vision_feat2, topo_feat2 = minamo(map2, graph2)
|
||||
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi)
|
||||
@@ -296,8 +301,8 @@ def train():
|
||||
vision_feat1, topo_feat1 = minamo(map1_val, graph1)
|
||||
vision_feat2, topo_feat2 = minamo(map2_val, graph2)
|
||||
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1)
|
||||
topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1)
|
||||
|
||||
# 计算损失
|
||||
loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val)
|
||||
@@ -312,7 +317,7 @@ def train():
|
||||
tqdm.write(f"Cycle {cycle} Minamo train ended.")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict()
|
||||
}, f"result/ginka.pth")
|
||||
}, f"result/minamo.pth")
|
||||
|
||||
print("Train ended.")
|
||||
|
||||
|
||||
@@ -10,8 +10,8 @@ class MinamoLoss(nn.Module):
|
||||
|
||||
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
|
||||
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
|
||||
# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}")
|
||||
vis_loss = self.loss(vis_pred, vis_true)
|
||||
topo_loss = self.loss(topo_pred, topo_true)
|
||||
# tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f} | {vis_loss.item():.12f}, {topo_loss.item():.12f}")
|
||||
# print(vis_loss.item(), topo_loss.item())
|
||||
return self.vision_weight * vis_loss + self.topo_weight * topo_loss
|
||||
|
||||
@@ -52,37 +52,6 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data:
|
||||
num_nodes=N
|
||||
)
|
||||
|
||||
def convert_soft_map_to_graph(map_probs: torch.Tensor):
|
||||
"""
|
||||
直接使用 Softmax 概率构建 soft 图结构
|
||||
"""
|
||||
C, H, W = map_probs.shape # [32, H, W]
|
||||
N = H * W
|
||||
device = map_probs.device
|
||||
|
||||
# 计算 soft 节点特征
|
||||
node_features = map_probs.view(C, N).T # [N, C]
|
||||
|
||||
# 计算 soft 邻接边(基于 soft 权重)
|
||||
edge_list = []
|
||||
for r in range(H):
|
||||
for c in range(W):
|
||||
node = r * W + c
|
||||
if c + 1 < W:
|
||||
right = node + 1
|
||||
edge_list.append([node, right])
|
||||
if r + 1 < H:
|
||||
down = node + W
|
||||
edge_list.append([node, down])
|
||||
|
||||
edge_index = torch.tensor(edge_list).t().to(device)
|
||||
|
||||
# 计算 soft 边权重(基于 Softmax 概率)
|
||||
soft_edge_weight = (map_probs[:, edge_index[0] // W, edge_index[0] % W] +
|
||||
map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2
|
||||
|
||||
return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight)
|
||||
|
||||
def batch_convert_soft_map_to_graph(batch_map_probs):
|
||||
"""
|
||||
处理 batch 维度,将 [B, C, H, W] 转换为批量图结构 Batch
|
||||
|
||||
Loading…
Reference in New Issue
Block a user