mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 二三阶段训练
Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
parent
b8f691269d
commit
d4f0768d2f
@ -87,9 +87,20 @@ disable_tqdm = not sys.stdout.isatty()
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
def _str2bool(v: str) -> bool:
|
||||
"""argparse 专用:将字符串 'True'/'False' 正确转为 bool。
|
||||
type=bool 会把任何非空字符串(包括 'False')解析为 True,故需此辅助。"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ('true', '1', 'yes'):
|
||||
return True
|
||||
if v.lower() in ('false', '0', 'no'):
|
||||
return False
|
||||
raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}")
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="三通道分拆 VQ 编码器预训练(方案 B)")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--resume", type=_str2bool, default=False)
|
||||
parser.add_argument("--state", type=str, default="result/pretrain_split/split-10.pth",
|
||||
help="续训时加载的检查点路径")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
@ -97,7 +108,7 @@ def parse_arguments():
|
||||
parser.add_argument("--epochs", type=int, default=60)
|
||||
parser.add_argument("--checkpoint", type=int, default=5,
|
||||
help="每隔多少 epoch 保存检查点并输出验证指标")
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
parser.add_argument("--load_optim", type=_str2bool, default=True)
|
||||
return parser.parse_args()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@ -99,9 +99,20 @@ disable_tqdm = not sys.stdout.isatty()
|
||||
# ---------------------------------------------------------------------------
|
||||
# 参数解析
|
||||
# ---------------------------------------------------------------------------
|
||||
def _str2bool(v: str) -> bool:
|
||||
"""argparse 专用:将字符串 'True'/'False' 正确转为 bool。
|
||||
type=bool 会把任何非空字符串(包括 'False')解析为 True,故需此辅助。"""
|
||||
if isinstance(v, bool):
|
||||
return v
|
||||
if v.lower() in ('true', '1', 'yes'):
|
||||
return True
|
||||
if v.lower() in ('false', '0', 'no'):
|
||||
return False
|
||||
raise argparse.ArgumentTypeError(f"布尔值应为 True/False,收到: {v!r}")
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description="VQ-VAE + MaskGIT 联合训练")
|
||||
parser.add_argument("--resume", type=bool, default=False)
|
||||
parser.add_argument("--resume", type=_str2bool, default=False)
|
||||
parser.add_argument("--state", type=str, default="result/joint/joint-10.pth",
|
||||
help="续训时加载的检查点路径")
|
||||
parser.add_argument("--train", type=str, default="ginka-dataset.json")
|
||||
@ -109,8 +120,8 @@ def parse_arguments():
|
||||
parser.add_argument("--epochs", type=int, default=100)
|
||||
parser.add_argument("--checkpoint", type=int, default=5,
|
||||
help="每隔多少 epoch 保存检查点并验证")
|
||||
parser.add_argument("--load_optim", type=bool, default=True)
|
||||
parser.add_argument("--freeze_vq", type=bool, default=False,
|
||||
parser.add_argument("--load_optim", type=_str2bool, default=True)
|
||||
parser.add_argument("--freeze_vq", type=_str2bool, default=False,
|
||||
help="(方案 B 阶段 1)冻结三路 VQ 编码器,仅训练 MaskGIT。"
|
||||
"适用于预训练权重加载后的热身阶段。")
|
||||
parser.add_argument("--pretrain_split", type=str, default="",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user