feat: 记录训练过程

This commit is contained in:
unanmed 2025-04-03 13:36:05 +08:00
parent 6a1aeaa77e
commit d7209a68a2
2 changed files with 30 additions and 6 deletions

View File

@ -87,8 +87,8 @@ function weisfeilerLehmanIteration(
});
weight *= decay;
});
// 把每个节点的原始标签也加上,权重使用最远权重再衰减2次,可以认为是资源重复率
weight *= decay ** 2;
// 把每个节点的原始标签也加上,权重使用最远权重再衰减1次,可以认为是资源重复率
weight *= decay;
nodes.forEach(node => {
if (!numMap.has(node.originalLabel)) {
numMap.set(node.originalLabel, weight);

View File

@ -19,21 +19,26 @@ from shared.image import matrix_to_image_cv
BATCH_SIZE = 32
EPOCHS_GINKA = 30
EPOCHS_MINAMO = 10
EPOCHS_MINAMO = 5
SOCKET_PATH = "./tmp/ginka_uds"
LOSS_PATH = "result/gan/a-loss.txt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/ginka_checkpoint", exist_ok=True)
os.makedirs("result/gan", exist_ok=True)
os.makedirs("tmp", exist_ok=True)
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")
def parse_arguments():
parser = argparse.ArgumentParser(description="training codes")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--from_state", type=str, default="result/ginka.pth")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
parser.add_argument("--validate", type=str, default='ginka-eval.json')
parser.add_argument("--from_cycle", type=int, default=2)
parser.add_argument("--from_cycle", type=int, default=0)
parser.add_argument("--to_cycle", type=int, default=100)
args = parser.parse_args()
return args
@ -141,8 +146,8 @@ def train():
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
criterion_ginka = GinkaLoss(minamo)
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-4)
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
criterion_minamo = MinamoLoss()
# 用于生成图片
@ -238,6 +243,9 @@ def train():
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
tqdm.write(f"Cycle {cycle} Ginka train ended.")
torch.save({
"model_state": ginka.state_dict()
}, f"result/gan/ginka-{cycle}.pth")
torch.save({
"model_state": ginka.state_dict()
}, f"result/ginka.pth")
@ -256,6 +264,17 @@ def train():
conn.sendall(buf)
data = parse_minamo_data(conn, prob_list)
minamo_dataset.set_data(data)
vis_sim = 0
topo_sim = 0
for _, _, vis, topo, _ in data:
vis_sim += vis
topo_sim += topo
vis_sim /= len(data)
topo_sim /= len(data)
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}')
# -------------------- 训练判别器
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
@ -315,9 +334,14 @@ def train():
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
tqdm.write(f"Cycle {cycle} Minamo train ended.")
torch.save({
"model_state": minamo.state_dict()
}, f"result/gan/minamo-{cycle}.pth")
torch.save({
"model_state": minamo.state_dict()
}, f"result/minamo.pth")
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
f.write(f' | Minamo: {avg_val_loss:.12f}\n')
print("Train ended.")