From 6a1aeaa77e8268d22cd9e0d74289a28cfbd7bbce Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 3 Apr 2025 12:12:13 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=B5=AE=E7=82=B9=E6=95=B0=E8=A2=AB?= =?UTF-8?q?=E6=84=8F=E5=A4=96=E8=BD=AC=E6=88=90=20LongTensor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/src/gan.ts | 2 +- data/src/topology/similarity.ts | 3 ++- ginka/dataset.py | 13 +++++---- ginka/model/loss.py | 7 ++--- ginka/train_gan.py | 47 ++++++++++++++++++--------------- minamo/model/loss.py | 2 +- shared/graph.py | 31 ---------------------- 7 files changed, 42 insertions(+), 63 deletions(-) diff --git a/data/src/gan.ts b/data/src/gan.ts index 2c2dc53..4edd797 100644 --- a/data/src/gan.ts +++ b/data/src/gan.ts @@ -35,7 +35,7 @@ function generateGANData( map: number[][] ) { const id2 = `$${id++}`; - const toTrain = chooseFrom(keys, 30); + const toTrain = chooseFrom(keys, 4); const data = toTrain.map(v => { const floor = refer.get(v); if (!floor) return []; diff --git a/data/src/topology/similarity.ts b/data/src/topology/similarity.ts index 9275d3c..c7fac25 100644 --- a/data/src/topology/similarity.ts +++ b/data/src/topology/similarity.ts @@ -87,7 +87,8 @@ function weisfeilerLehmanIteration( }); weight *= decay; }); - // 把每个节点的原始标签也加上,权重使用最远权重,可以认为是资源重复率 + // 把每个节点的原始标签也加上,权重使用最远权重再衰减2次,可以认为是资源重复率 + weight *= decay ** 2; nodes.forEach(node => { if (!numMap.has(node.originalLabel)) { numMap.set(node.originalLabel, weight); diff --git a/ginka/dataset.py b/ginka/dataset.py index bca3841..478d779 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -37,7 +37,10 @@ class GinkaDataset(Dataset): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - target_smooth = random_smooth_onehot(target) + min_main = random.uniform(0.75, 0.9) + max_main = random.uniform(0.9, 1) + epsilon = random.uniform(0, 0.25) + target_smooth = random_smooth_onehot(target, min_main, max_main, epsilon) graph = differentiable_convert_to_data(target_smooth).to(self.device) target = target.to(self.device) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) @@ -67,12 +70,12 @@ class MinamoGANDataset(Dataset): item = self.data[idx] map1, map2, vis_sim, topo_sim, review = item - map1 = torch.LongTensor(map1) - map2 = torch.LongTensor(map2) # 检查是否有 review 标签,没有的话说明是概率分布,不需要任何转换 if review: - map1 = F.one_hot(map1, num_classes=32).permute(2, 0, 1).float() # [32, H, W] - map2 = F.one_hot(map2, num_classes=32).permute(2, 0, 1).float() # [32, H, W] + map1 = F.one_hot(torch.LongTensor(map1), num_classes=32).permute(2, 0, 1).float() # [32, H, W] + else: + map1 = torch.FloatTensor(map1) + map2 = F.one_hot(torch.LongTensor(map2), num_classes=32).permute(2, 0, 1).float() # [32, H, W] min_main = random.uniform(0.75, 0.9) max_main = random.uniform(0.9, 1) diff --git a/ginka/model/loss.py b/ginka/model/loss.py index 272e6d4..38bdcf1 100644 --- a/ginka/model/loss.py +++ b/ginka/model/loss.py @@ -249,9 +249,10 @@ class GinkaLoss(nn.Module): graph = batch_convert_soft_map_to_graph(pred) pred_vision_feat, pred_topo_feat = self.minamo(pred, graph) - vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=-1) - topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=-1) - minamo_sim = 0.2 * vision_sim + 0.8 * topo_sim + vision_sim = F.cosine_similarity(pred_vision_feat, target_vision_feat, dim=1) + topo_sim = F.cosine_similarity(pred_topo_feat, target_topo_feat, dim=1) + minamo_sim = 0 * vision_sim + 1 * topo_sim + # tqdm.write(f"{vision_sim.mean().item():.12f}, {topo_sim.mean().item():.12f}") minamo_loss = (1.0 - minamo_sim).mean() tqdm.write( diff --git a/ginka/train_gan.py b/ginka/train_gan.py index 416905d..a18561c 100644 --- a/ginka/train_gan.py +++ b/ginka/train_gan.py @@ -19,7 +19,7 @@ from shared.image import matrix_to_image_cv BATCH_SIZE = 32 EPOCHS_GINKA = 30 -EPOCHS_MINAMO = 15 +EPOCHS_MINAMO = 10 SOCKET_PATH = "./tmp/ginka_uds" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -40,9 +40,9 @@ def parse_arguments(): def parse_ginka_batch(batch): target = batch["target"].to(device) - target_vision_feat = batch["target_vision_feat"].to(device) - target_topo_feat = batch["target_topo_feat"].to(device) - feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) + target_vision_feat = batch["target_vision_feat"].to(device).squeeze(1) + target_topo_feat = batch["target_topo_feat"].to(device).squeeze(1) + feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=1).to(device) return target, target_vision_feat, target_topo_feat, feat_vec @@ -133,8 +133,8 @@ def train(): minamo_dataset_val = MinamoGANDataset("datasets/minamo-eval-1.json") ginka_dataloader = DataLoader(ginka_dataset, batch_size=BATCH_SIZE, shuffle=True) ginka_dataloader_val = DataLoader(ginka_dataset_val, batch_size=BATCH_SIZE, shuffle=True) - minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE, shuffle=True) - minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE, shuffle=True) + minamo_dataloader = DataLoader(minamo_dataset, batch_size=BATCH_SIZE // 2, shuffle=True) + minamo_dataloader_val = DataLoader(minamo_dataset_val, batch_size=BATCH_SIZE // 2, shuffle=True) # 设定优化器与调度器 optimizer_ginka = optim.AdamW(ginka.parameters(), lr=1e-3) @@ -142,7 +142,7 @@ def train(): criterion_ginka = GinkaLoss(minamo) optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3) - scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6) + scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6) criterion_minamo = MinamoLoss() # 用于生成图片 @@ -158,20 +158,18 @@ def train(): server.bind(SOCKET_PATH) server.listen(1) - print("Waiting for client connection...") - conn, _ = server.accept() - print("Client connected.") - if args.resume: data = torch.load(args.from_state, map_location=device) ginka.load_state_dict(data["model_state"], strict=False) - if args.load_optim: - optimizer_ginka.load_state_dict(data["optimizer_state"]) print("Train from loaded state.") else: # 从头开始训练的话,初始时先把 minamo 损失值权重改为 0 criterion_ginka.weight[0] = 0.0 + + print("Waiting for client connection...") + conn, _ = server.accept() + print("Client connected.") for cycle in tqdm(range(args.from_cycle, args.to_cycle), desc="Total Progress"): # -------------------- 训练生成器 @@ -217,10 +215,7 @@ def train(): loss_val += losses.item() if epoch + 1 == EPOCHS_GINKA: # 最后一次验证的时候顺带生成图片 - prob = output_softmax.cpu().numpy() - prob_list = np.concatenate((prob_list, prob), axis=0) map_matrix = torch.argmax(output, dim=1).cpu().numpy() - gen_list = np.concatenate((gen_list, map_matrix), axis=0) for matrix in map_matrix: image = matrix_to_image_cv(matrix, tile_dict) cv2.imwrite(f"result/ginka_img/{idx}.png", image) @@ -231,6 +226,16 @@ def train(): torch.save({ "model_state": ginka.state_dict() }, f"result/ginka_checkpoint/{epoch + 1}.pth") + + # 使用训练集生成 minamo 训练数据,更准确 + with torch.no_grad(): + for batch in ginka_dataloader: + target, target_vision_feat, target_topo_feat, feat_vec = parse_ginka_batch(batch) + output, output_softmax = ginka(feat_vec) + prob = output_softmax.cpu().numpy() + prob_list = np.concatenate((prob_list, prob), axis=0) + map_matrix = torch.argmax(output, dim=1).cpu().numpy() + gen_list = np.concatenate((gen_list, map_matrix), axis=0) tqdm.write(f"Cycle {cycle} Ginka train ended.") torch.save({ @@ -269,8 +274,8 @@ def train(): vision_feat1, topo_feat1 = minamo(map1, graph1) vision_feat2, topo_feat2 = minamo(map2, graph2) - vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) - topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1) # 计算损失 loss = criterion_minamo(vision_pred, topo_pred, vision_simi, topo_simi) @@ -296,8 +301,8 @@ def train(): vision_feat1, topo_feat1 = minamo(map1_val, graph1) vision_feat2, topo_feat2 = minamo(map2_val, graph2) - vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, -1).unsqueeze(-1) - topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1) + vision_pred = F.cosine_similarity(vision_feat1, vision_feat2, dim=1).unsqueeze(-1) + topo_pred = F.cosine_similarity(topo_feat1, topo_feat2, dim=1).unsqueeze(-1) # 计算损失 loss_val = criterion_minamo(vision_pred, topo_pred, vision_simi_val, topo_simi_val) @@ -312,7 +317,7 @@ def train(): tqdm.write(f"Cycle {cycle} Minamo train ended.") torch.save({ "model_state": minamo.state_dict() - }, f"result/ginka.pth") + }, f"result/minamo.pth") print("Train ended.") diff --git a/minamo/model/loss.py b/minamo/model/loss.py index ffcf575..5fe818f 100644 --- a/minamo/model/loss.py +++ b/minamo/model/loss.py @@ -10,8 +10,8 @@ class MinamoLoss(nn.Module): def forward(self, vis_pred, topo_pred, vis_true, topo_true): # print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape) - # tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f}") vis_loss = self.loss(vis_pred, vis_true) topo_loss = self.loss(topo_pred, topo_true) + # tqdm.write(f"{vis_pred[0].item():.12f}, {vis_true[0].item():.12f}, {topo_pred[0].item():.12f}, {topo_true[0].item():.12f} | {vis_loss.item():.12f}, {topo_loss.item():.12f}") # print(vis_loss.item(), topo_loss.item()) return self.vision_weight * vis_loss + self.topo_weight * topo_loss diff --git a/shared/graph.py b/shared/graph.py index c29eb9c..4ba5adc 100644 --- a/shared/graph.py +++ b/shared/graph.py @@ -52,37 +52,6 @@ def differentiable_convert_to_data(map_probs: torch.Tensor) -> Data: num_nodes=N ) -def convert_soft_map_to_graph(map_probs: torch.Tensor): - """ - 直接使用 Softmax 概率构建 soft 图结构 - """ - C, H, W = map_probs.shape # [32, H, W] - N = H * W - device = map_probs.device - - # 计算 soft 节点特征 - node_features = map_probs.view(C, N).T # [N, C] - - # 计算 soft 邻接边(基于 soft 权重) - edge_list = [] - for r in range(H): - for c in range(W): - node = r * W + c - if c + 1 < W: - right = node + 1 - edge_list.append([node, right]) - if r + 1 < H: - down = node + W - edge_list.append([node, down]) - - edge_index = torch.tensor(edge_list).t().to(device) - - # 计算 soft 边权重(基于 Softmax 概率) - soft_edge_weight = (map_probs[:, edge_index[0] // W, edge_index[0] % W] + - map_probs[:, edge_index[1] // W, edge_index[1] % W]) / 2 - - return Data(x=node_features, edge_index=edge_index, edge_attr=soft_edge_weight) - def batch_convert_soft_map_to_graph(batch_map_probs): """ 处理 batch 维度,将 [B, C, H, W] 转换为批量图结构 Batch