From c000b90794abee6872c947793e7c99d260692327 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 11 Mar 2026 16:33:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E8=AE=AD=E7=BB=83=E6=97=B6=E7=9A=84=20?= =?UTF-8?q?heatmap?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_maskGIT.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index c9f6a13..bfb8bb4 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -95,7 +95,7 @@ def train(): for file in os.listdir('tiles2'): name = os.path.splitext(file)[0] tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) - + # 接续训练 if args.resume: data_ginka = torch.load(args.state_ginka, map_location=device) @@ -105,7 +105,7 @@ def train(): if args.load_optim: if data_ginka.get("optim_state") is not None: optimizer.load_state_dict(data_ginka["optim_state"]) - + print("Train from loaded state.") for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): @@ -113,9 +113,14 @@ def train(): for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): target_map = batch["target_map"].to(device) - cond = batch["val_cond"].to(device) + cond = batch["cond"].to(device) + heatmap = batch["heatmap"].to(device) B, H, W = target_map.shape + target_map = target_map.view(B, H * W) + rand = torch.randn_like(heatmap).to(device) * 0.05 + if random.random() > 0.5: + heatmap = heatmap + rand mask = np.zeros((B, H * W)) for i in range(B): @@ -127,7 +132,7 @@ def train(): masked_input = target_map.clone() masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记 - logits = model(masked_input, cond) + logits = model(masked_input, cond, heatmap) loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) loss = (loss * mask).sum() / (mask.sum() + 1e-6) @@ -164,7 +169,8 @@ def train(): for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm): # 1. 常规生成 target_map = batch["target_map"].to(device) - cond = batch["val_cond"].to(device) + cond = batch["cond"].to(device) + heatmap = batch["heatmap"].to(device) B, H, W = target_map.shape target_map = target_map.view(B, H * W) @@ -178,7 +184,7 @@ def train(): masked_input = target_map.clone() masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记 - logits = model(masked_input, cond) + logits = model(masked_input, cond, heatmap) loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)