From 90cfe54bd2a2c0141e8f446e39f9184a521da509 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 23 Apr 2026 18:24:18 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E8=81=94=E5=90=88=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_heatmap.py | 2 +- ginka/train_joint.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ginka/train_heatmap.py b/ginka/train_heatmap.py index 546688d..5c0d836 100644 --- a/ginka/train_heatmap.py +++ b/ginka/train_heatmap.py @@ -219,7 +219,7 @@ def train(): print("Train ended.") torch.save({ - "model_state": maskGIT.state_dict(), + "model_state": model.state_dict(), }, f"result/ginka_heatmap.pth") def get_nms_sampling_count(): diff --git a/ginka/train_joint.py b/ginka/train_joint.py index 06e495e..68cbc57 100644 --- a/ginka/train_joint.py +++ b/ginka/train_joint.py @@ -17,7 +17,7 @@ from .heatmap.model import GinkaHeatmapModel from .maskGIT.model import GinkaMaskGIT -BATCH_SIZE = 128 +BATCH_SIZE = 64 VAL_BATCH_DIVIDER = 64 NUM_CLASSES = 16 MASK_TOKEN = 15 @@ -33,6 +33,8 @@ D_MODEL_DIFFUSION = 128 T_DIFFUSION = 100 MIN_MASK = 0 MAX_MASK = 1 +CE_WEIGHT = 0.5 +DROP_RATE = 0.2 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -48,15 +50,13 @@ disable_tqdm = not sys.stdout.isatty() def parse_arguments(): parser = argparse.ArgumentParser(description="joint training codes") parser.add_argument("--resume", type=bool, default=False) - parser.add_argument("--state_heatmap", type=str, default="result/heatmap/ginka-100.pth") + parser.add_argument("--state_heatmap", type=str, default="result/ginka_heatmap.pth") parser.add_argument("--train", type=str, default="ginka-dataset.json") parser.add_argument("--validate", type=str, default="ginka-eval.json") parser.add_argument("--epochs", type=int, default=50) parser.add_argument("--checkpoint", type=int, default=5) parser.add_argument("--load_optim", type=bool, default=True) parser.add_argument("--maskgit_path", type=str, default="result/ginka_transformer.pth") - parser.add_argument("--ce_weight", type=float, default=1.0) - parser.add_argument("--cfg_drop_rate", type=float, default=0.2) args = parser.parse_args() return args @@ -231,7 +231,7 @@ def train(): cond_for_diffusion = cond_heatmap use_unconditional_branch = False - if np.random.rand() < args.cfg_drop_rate: + if np.random.rand() < DROP_RATE: cond_for_diffusion = torch.zeros_like(cond_heatmap) use_unconditional_branch = True @@ -245,7 +245,7 @@ def train(): generated_heatmap = predict_x0(diffusion, x_t, pred_noise_for_joint, t) maskgit_loss = maskgit_joint_loss(maskgit, generated_heatmap, target_map) - loss = diffusion_loss + args.ce_weight * maskgit_loss + loss = diffusion_loss + CE_WEIGHT * maskgit_loss loss.backward() optimizer.step() @@ -275,7 +275,7 @@ def train(): checkpoint_path, ) - metrics = validate(model, maskgit, diffusion, dataloader_val, args.ce_weight) + metrics = validate(model, maskgit, diffusion, dataloader_val, CE_WEIGHT) tqdm.write( f"[Validate {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] " f"E: {epoch + 1} | "