fix: 二三阶段训练

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
unanmed 2026-05-05 23:50:01 +08:00
parent b8f691269d
commit d4f0768d2f
2 changed files with 27 additions and 5 deletions

View File

@ -87,9 +87,20 @@ disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# 参数解析
# ---------------------------------------------------------------------------
def _str2bool(v: str) -> bool:
"""argparse 专用:将字符串 'True'/'False' 正确转为 bool。
type=bool 会把任何非空字符串包括 'False'解析为 True故需此辅助"""
if isinstance(v, bool):
return v
if v.lower() in ('true', '1', 'yes'):
return True
if v.lower() in ('false', '0', 'no'):
return False
raise argparse.ArgumentTypeError(f"布尔值应为 True/False收到: {v!r}")
def parse_arguments():
parser = argparse.ArgumentParser(description="三通道分拆 VQ 编码器预训练(方案 B")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--resume", type=_str2bool, default=False)
parser.add_argument("--state", type=str, default="result/pretrain_split/split-10.pth",
help="续训时加载的检查点路径")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
@ -97,7 +108,7 @@ def parse_arguments():
parser.add_argument("--epochs", type=int, default=60)
parser.add_argument("--checkpoint", type=int, default=5,
help="每隔多少 epoch 保存检查点并输出验证指标")
parser.add_argument("--load_optim", type=bool, default=True)
parser.add_argument("--load_optim", type=_str2bool, default=True)
return parser.parse_args()
# ---------------------------------------------------------------------------

View File

@ -99,9 +99,20 @@ disable_tqdm = not sys.stdout.isatty()
# ---------------------------------------------------------------------------
# 参数解析
# ---------------------------------------------------------------------------
def _str2bool(v: str) -> bool:
"""argparse 专用:将字符串 'True'/'False' 正确转为 bool。
type=bool 会把任何非空字符串包括 'False'解析为 True故需此辅助"""
if isinstance(v, bool):
return v
if v.lower() in ('true', '1', 'yes'):
return True
if v.lower() in ('false', '0', 'no'):
return False
raise argparse.ArgumentTypeError(f"布尔值应为 True/False收到: {v!r}")
def parse_arguments():
parser = argparse.ArgumentParser(description="VQ-VAE + MaskGIT 联合训练")
parser.add_argument("--resume", type=bool, default=False)
parser.add_argument("--resume", type=_str2bool, default=False)
parser.add_argument("--state", type=str, default="result/joint/joint-10.pth",
help="续训时加载的检查点路径")
parser.add_argument("--train", type=str, default="ginka-dataset.json")
@ -109,8 +120,8 @@ def parse_arguments():
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--checkpoint", type=int, default=5,
help="每隔多少 epoch 保存检查点并验证")
parser.add_argument("--load_optim", type=bool, default=True)
parser.add_argument("--freeze_vq", type=bool, default=False,
parser.add_argument("--load_optim", type=_str2bool, default=True)
parser.add_argument("--freeze_vq", type=_str2bool, default=False,
help="(方案 B 阶段 1冻结三路 VQ 编码器,仅训练 MaskGIT。"
"适用于预训练权重加载后的热身阶段。")
parser.add_argument("--pretrain_split", type=str, default="",