feat: 一轮对抗训练的 sh 脚本

This commit is contained in:
unanmed 2025-03-22 18:50:35 +08:00
parent ca068bbea3
commit a4167b59d6
2 changed files with 16 additions and 1 deletions

15
cycle.sh Normal file
View File

@ -0,0 +1,15 @@
# 训练部分
python3 -m minamo.train --epochs 30 --resume true
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
python3 -m minamo.train --epochs 10 --resume true
python3 -m ginka.train --epochs 10 --resume true
python3 -m ginka.validate
# 训练完毕,处理数据
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
cd data
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
cd ..

View File

@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
state = torch.load("result/ginka.pth", map_location=device)["model_state"]
state = torch.load("result/ginka_checkpoint/10.pth", map_location=device)["model_state"]
model.load_state_dict(state)
model.to(device)