mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 14:31:11 +08:00
feat: 记录训练过程
This commit is contained in:
parent
6a1aeaa77e
commit
d7209a68a2
@ -87,8 +87,8 @@ function weisfeilerLehmanIteration(
|
||||
});
|
||||
weight *= decay;
|
||||
});
|
||||
// 把每个节点的原始标签也加上,权重使用最远权重再衰减2次,可以认为是资源重复率
|
||||
weight *= decay ** 2;
|
||||
// 把每个节点的原始标签也加上,权重使用最远权重再衰减1次,可以认为是资源重复率
|
||||
weight *= decay;
|
||||
nodes.forEach(node => {
|
||||
if (!numMap.has(node.originalLabel)) {
|
||||
numMap.set(node.originalLabel, weight);
|
||||
|
||||
@ -19,21 +19,26 @@ from shared.image import matrix_to_image_cv
|
||||
|
||||
BATCH_SIZE = 32
|
||||
EPOCHS_GINKA = 30
|
||||
EPOCHS_MINAMO = 10
|
||||
EPOCHS_MINAMO = 5
|
||||
SOCKET_PATH = "./tmp/ginka_uds"
|
||||
LOSS_PATH = "result/gan/a-loss.txt"
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||
os.makedirs("result/gan", exist_ok=True)
|
||||
os.makedirs("tmp", exist_ok=True)
|
||||
|
||||
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
|
||||
f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="training codes")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--from_state", type=str, default="result/ginka.pth")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
parser.add_argument("--validate", type=str, default='ginka-eval.json')
|
||||
parser.add_argument("--from_cycle", type=int, default=2)
|
||||
parser.add_argument("--from_cycle", type=int, default=0)
|
||||
parser.add_argument("--to_cycle", type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
@ -141,8 +146,8 @@ def train():
|
||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion_ginka = GinkaLoss(minamo)
|
||||
|
||||
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-4)
|
||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
|
||||
criterion_minamo = MinamoLoss()
|
||||
|
||||
# 用于生成图片
|
||||
@ -238,6 +243,9 @@ def train():
|
||||
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
|
||||
|
||||
tqdm.write(f"Cycle {cycle} Ginka train ended.")
|
||||
torch.save({
|
||||
"model_state": ginka.state_dict()
|
||||
}, f"result/gan/ginka-{cycle}.pth")
|
||||
torch.save({
|
||||
"model_state": ginka.state_dict()
|
||||
}, f"result/ginka.pth")
|
||||
@ -256,6 +264,17 @@ def train():
|
||||
conn.sendall(buf)
|
||||
data = parse_minamo_data(conn, prob_list)
|
||||
minamo_dataset.set_data(data)
|
||||
vis_sim = 0
|
||||
topo_sim = 0
|
||||
for _, _, vis, topo, _ in data:
|
||||
vis_sim += vis
|
||||
topo_sim += topo
|
||||
|
||||
vis_sim /= len(data)
|
||||
topo_sim /= len(data)
|
||||
|
||||
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
|
||||
f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}')
|
||||
|
||||
# -------------------- 训练判别器
|
||||
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
|
||||
@ -315,9 +334,14 @@ def train():
|
||||
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
|
||||
|
||||
tqdm.write(f"Cycle {cycle} Minamo train ended.")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict()
|
||||
}, f"result/gan/minamo-{cycle}.pth")
|
||||
torch.save({
|
||||
"model_state": minamo.state_dict()
|
||||
}, f"result/minamo.pth")
|
||||
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
|
||||
f.write(f' | Minamo: {avg_val_loss:.12f}\n')
|
||||
|
||||
print("Train ended.")
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user