From d4f0768d2f43941cf2c22c4fc7821effade69378 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Tue, 5 May 2026 23:50:01 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E4=BA=8C=E4=B8=89=E9=98=B6=E6=AE=B5?= =?UTF-8?q?=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Copilot --- ginka/train_pretrain_split.py | 15 +++++++++++++-- ginka/train_vq.py | 17 ++++++++++++++--- 2 files changed, 27 insertions(+), 5 deletions(-) diff --git a/ginka/train_pretrain_split.py b/ginka/train_pretrain_split.py index 5e68c5c..c7e055c 100644 --- a/ginka/train_pretrain_split.py +++ b/ginka/train_pretrain_split.py @@ -87,9 +87,20 @@ disable_tqdm = not sys.stdout.isatty() # --------------------------------------------------------------------------- # 参数解析 # --------------------------------------------------------------------------- +def _str2bool(v: str) -> bool: + """argparse 专用:将字符串 'True'/'False' 正确转为 bool。 + type=bool 会把任何非空字符串(包括 'False')解析为 True,故需此辅助。""" + if isinstance(v, bool): + return v + if v.lower() in ('true', '1', 'yes'): + return True + if v.lower() in ('false', '0', 'no'): + return False + raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}") + def parse_arguments(): parser = argparse.ArgumentParser(description="三通道分拆 VQ 编码器预训练(方案 B)") - parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--resume", type=_str2bool, default=False) parser.add_argument("--state", type=str, default="result/pretrain_split/split-10.pth", help="续训时加载的检查点路径") parser.add_argument("--train", type=str, default="ginka-dataset.json") @@ -97,7 +108,7 @@ def parse_arguments(): parser.add_argument("--epochs", type=int, default=60) parser.add_argument("--checkpoint", type=int, default=5, help="每隔多少 epoch 保存检查点并输出验证指标") - parser.add_argument("--load_optim", type=bool, default=True) + parser.add_argument("--load_optim", type=_str2bool, default=True) return parser.parse_args() # --------------------------------------------------------------------------- diff --git a/ginka/train_vq.py b/ginka/train_vq.py index 29bf6e4..3d16f39 100644 --- a/ginka/train_vq.py +++ b/ginka/train_vq.py @@ -99,9 +99,20 @@ disable_tqdm = not sys.stdout.isatty() # --------------------------------------------------------------------------- # 参数解析 # --------------------------------------------------------------------------- +def _str2bool(v: str) -> bool: + """argparse 专用:将字符串 'True'/'False' 正确转为 bool。 + type=bool 会把任何非空字符串(包括 'False')解析为 True,故需此辅助。""" + if isinstance(v, bool): + return v + if v.lower() in ('true', '1', 'yes'): + return True + if v.lower() in ('false', '0', 'no'): + return False + raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}") + def parse_arguments(): parser = argparse.ArgumentParser(description="VQ-VAE + MaskGIT 联合训练") - parser.add_argument("--resume", type=bool, default=False) + parser.add_argument("--resume", type=_str2bool, default=False) parser.add_argument("--state", type=str, default="result/joint/joint-10.pth", help="续训时加载的检查点路径") parser.add_argument("--train", type=str, default="ginka-dataset.json") @@ -109,8 +120,8 @@ def parse_arguments(): parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--checkpoint", type=int, default=5, help="每隔多少 epoch 保存检查点并验证") - parser.add_argument("--load_optim", type=bool, default=True) - parser.add_argument("--freeze_vq", type=bool, default=False, + parser.add_argument("--load_optim", type=_str2bool, default=True) + parser.add_argument("--freeze_vq", type=_str2bool, default=False, help="(方案 B 阶段 1)冻结三路 VQ 编码器,仅训练 MaskGIT。" "适用于预训练权重加载后的热身阶段。") parser.add_argument("--pretrain_split", type=str, default="",