mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-23 03:51:11 +08:00
feat: 记录训练过程
This commit is contained in:
parent
6a1aeaa77e
commit
d7209a68a2
@ -87,8 +87,8 @@ function weisfeilerLehmanIteration(
|
|||||||
});
|
});
|
||||||
weight *= decay;
|
weight *= decay;
|
||||||
});
|
});
|
||||||
// 把每个节点的原始标签也加上,权重使用最远权重再衰减2次,可以认为是资源重复率
|
// 把每个节点的原始标签也加上,权重使用最远权重再衰减1次,可以认为是资源重复率
|
||||||
weight *= decay ** 2;
|
weight *= decay;
|
||||||
nodes.forEach(node => {
|
nodes.forEach(node => {
|
||||||
if (!numMap.has(node.originalLabel)) {
|
if (!numMap.has(node.originalLabel)) {
|
||||||
numMap.set(node.originalLabel, weight);
|
numMap.set(node.originalLabel, weight);
|
||||||
|
|||||||
@ -19,21 +19,26 @@ from shared.image import matrix_to_image_cv
|
|||||||
|
|
||||||
BATCH_SIZE = 32
|
BATCH_SIZE = 32
|
||||||
EPOCHS_GINKA = 30
|
EPOCHS_GINKA = 30
|
||||||
EPOCHS_MINAMO = 10
|
EPOCHS_MINAMO = 5
|
||||||
SOCKET_PATH = "./tmp/ginka_uds"
|
SOCKET_PATH = "./tmp/ginka_uds"
|
||||||
|
LOSS_PATH = "result/gan/a-loss.txt"
|
||||||
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
os.makedirs("result/ginka_checkpoint", exist_ok=True)
|
||||||
|
os.makedirs("result/gan", exist_ok=True)
|
||||||
os.makedirs("tmp", exist_ok=True)
|
os.makedirs("tmp", exist_ok=True)
|
||||||
|
|
||||||
|
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
|
||||||
|
f.write(f"---------- {datetime.now().strftime('%Y-%m-%d %H:%M:%S')} ----------\n")
|
||||||
|
|
||||||
def parse_arguments():
|
def parse_arguments():
|
||||||
parser = argparse.ArgumentParser(description="training codes")
|
parser = argparse.ArgumentParser(description="training codes")
|
||||||
parser.add_argument("--resume", type=bool, default=False)
|
parser.add_argument("--resume", type=bool, default=False)
|
||||||
parser.add_argument("--from_state", type=str, default="result/ginka.pth")
|
parser.add_argument("--from_state", type=str, default="result/ginka.pth")
|
||||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||||
parser.add_argument("--validate", type=str, default='ginka-eval.json')
|
parser.add_argument("--validate", type=str, default='ginka-eval.json')
|
||||||
parser.add_argument("--from_cycle", type=int, default=2)
|
parser.add_argument("--from_cycle", type=int, default=0)
|
||||||
parser.add_argument("--to_cycle", type=int, default=100)
|
parser.add_argument("--to_cycle", type=int, default=100)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
return args
|
return args
|
||||||
@ -141,8 +146,8 @@ def train():
|
|||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_ginka, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion_ginka = GinkaLoss(minamo)
|
criterion_ginka = GinkaLoss(minamo)
|
||||||
|
|
||||||
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-3)
|
optimizer_minamo = optim.AdamW(minamo.parameters(), lr=1e-4)
|
||||||
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler_minamo = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer_minamo, T_0=5, T_mult=2, eta_min=1e-6)
|
||||||
criterion_minamo = MinamoLoss()
|
criterion_minamo = MinamoLoss()
|
||||||
|
|
||||||
# 用于生成图片
|
# 用于生成图片
|
||||||
@ -238,6 +243,9 @@ def train():
|
|||||||
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
|
gen_list = np.concatenate((gen_list, map_matrix), axis=0)
|
||||||
|
|
||||||
tqdm.write(f"Cycle {cycle} Ginka train ended.")
|
tqdm.write(f"Cycle {cycle} Ginka train ended.")
|
||||||
|
torch.save({
|
||||||
|
"model_state": ginka.state_dict()
|
||||||
|
}, f"result/gan/ginka-{cycle}.pth")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": ginka.state_dict()
|
"model_state": ginka.state_dict()
|
||||||
}, f"result/ginka.pth")
|
}, f"result/ginka.pth")
|
||||||
@ -256,6 +264,17 @@ def train():
|
|||||||
conn.sendall(buf)
|
conn.sendall(buf)
|
||||||
data = parse_minamo_data(conn, prob_list)
|
data = parse_minamo_data(conn, prob_list)
|
||||||
minamo_dataset.set_data(data)
|
minamo_dataset.set_data(data)
|
||||||
|
vis_sim = 0
|
||||||
|
topo_sim = 0
|
||||||
|
for _, _, vis, topo, _ in data:
|
||||||
|
vis_sim += vis
|
||||||
|
topo_sim += topo
|
||||||
|
|
||||||
|
vis_sim /= len(data)
|
||||||
|
topo_sim /= len(data)
|
||||||
|
|
||||||
|
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
|
||||||
|
f.write(f'Cycle {cycle} | Ginka Vision Similarity: {vis_sim:.12f} | Ginka Topo Similarity: {topo_sim:.12f} | Ginka Loss: {avg_val_loss:.12f}')
|
||||||
|
|
||||||
# -------------------- 训练判别器
|
# -------------------- 训练判别器
|
||||||
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
|
for epoch in tqdm(range(EPOCHS_MINAMO), leave=False, desc="Training Minamo Model"):
|
||||||
@ -315,9 +334,14 @@ def train():
|
|||||||
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
|
}, f"result/minamo_checkpoint/{epoch + 1}.pth")
|
||||||
|
|
||||||
tqdm.write(f"Cycle {cycle} Minamo train ended.")
|
tqdm.write(f"Cycle {cycle} Minamo train ended.")
|
||||||
|
torch.save({
|
||||||
|
"model_state": minamo.state_dict()
|
||||||
|
}, f"result/gan/minamo-{cycle}.pth")
|
||||||
torch.save({
|
torch.save({
|
||||||
"model_state": minamo.state_dict()
|
"model_state": minamo.state_dict()
|
||||||
}, f"result/minamo.pth")
|
}, f"result/minamo.pth")
|
||||||
|
with open(LOSS_PATH, 'a', encoding='utf-8') as f:
|
||||||
|
f.write(f' | Minamo: {avg_val_loss:.12f}\n')
|
||||||
|
|
||||||
print("Train ended.")
|
print("Train ended.")
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user