From d7209a68a2aa0f9d97f438627fb765b0c09f2bda Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 3 Apr 2025 13:36:05 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AE=B0=E5=BD=95=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E8=BF=87=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- data/src/topology/similarity.ts | 4 ++-- ginka/train_gan.py | 32 ++++++++++++++++++++++++++++---- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/data/src/topology/similarity.ts b/data/src/topology/similarity.ts index c7fac25..abb345b 100644 --- a/data/src/topology/similarity.ts +++ b/data/src/topology/similarity.ts @@ -87,8 +87,8 @@ function weisfeilerLehmanIteration( }); weight *= decay; }); - // 把每个节点的原始标签也加上,权重使用最远权重再衰减2次,可以认为是资源重复率 - weight *= decay ** 2; + // 把每个节点的原始标签也加上,权重使用最远权重再衰减1次,可以认为是资源重复率 + weight *= decay; nodes.forEach(node => { if (!numMap.has(node.originalLabel)) { numMap.set(node.originalLabel, weight); diff --git a/ginka/train_gan.py b/ginka/train_gan.py index a18561c..ee9ba08 100644 --- a/ginka/train_gan.py +++ b/ginka/train_gan.py @@ -19,21 +19,26 @@ from shared.image import matrix_to_image_cv BATCH_SIZE = 32 EPOCHS_GINKA = 30 -EPOCHS_MINAMO = 10 +EPOCHS_MINAMO = 5 SOCKET_PATH = "./tmp/ginka_uds" +LOSS_PATH = "result/gan/a-loss.txt" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") os.makedirs("result", exist_ok=True) os.makedirs("result/ginka_checkpoint", exist_ok=True) +os.makedirs("result/gan", exist_ok=True) os.makedirs("tmp", exist_ok=True) +with open(LOSS_PATH, 'a', encoding='utf-8') as f: + f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n") + def parse_arguments(): parser = argparse.ArgumentParser(description="training codes") parser.add_argument("--resume", type=bool, default=False) parser.add_argument("--from_state", type=str, default="result/ginka.pth") parser.add_argument("--train", type=str, default="ginka-dataset.json") parser.add_argument("--validate", type=str, default='ginka-eval.json') - parser.add_argument("--from_cycle", type=int, default=2) + parser.add_argument("--from_cycle", type=int, default=0) parser.add_argument("--to_cycle", type=int, default=100) args = parser.parse_args() return args @@ -141,8 +146,8 @@ def train(): scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6) criterion_ginka = GinkaLoss(minamo) - optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3) - scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6) + optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-4) + scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6) criterion_minamo = MinamoLoss() # 用于生成图片 @@ -238,6 +243,9 @@ def train(): gen_list = np.concatenate((gen_list, map_matrix), axis=0) tqdm.write(f"Cycle {cycle} Ginka train ended.") + torch.save({ + "model_state": ginka.state_dict() + }, f"result/gan/ginka-{cycle}.pth") torch.save({ "model_state": ginka.state_dict() }, f"result/ginka.pth") @@ -256,6 +264,17 @@ def train(): conn.sendall(buf) data = parse_minamo_data(conn, prob_list) minamo_dataset.set_data(data) + vis_sim = 0 + topo_sim = 0 + for _, _, vis, topo, _ in data: + vis_sim += vis + topo_sim += topo + + vis_sim /= len(data) + topo_sim /= len(data) + + with open(LOSS_PATH, 'a', encoding='utf-8') as f: + f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}') # -------------------- 训练判别器 for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"): @@ -315,9 +334,14 @@ def train(): }, f"result/minamo_checkpoint/{epoch + 1}.pth") tqdm.write(f"Cycle {cycle} Minamo train ended.") + torch.save({ + "model_state": minamo.state_dict() + }, f"result/gan/minamo-{cycle}.pth") torch.save({ "model_state": minamo.state_dict() }, f"result/minamo.pth") + with open(LOSS_PATH, 'a', encoding='utf-8') as f: + f.write(f' | Minamo: {avg_val_loss:.12f}\n') print("Train ended.")