feat: 训练时的 heatmap

This commit is contained in:
unanmed 2026-03-11 16:33:15 +08:00
parent 22a2db464f
commit c000b90794

View File

@@ -95,7 +95,7 @@ def train():
for file in os.listdir('tiles2'):
name = os.path.splitext(file)[0]
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
# 接续训练
if args.resume:
data_ginka = torch.load(args.state_ginka, map_location=device)
@@ -105,7 +105,7 @@ def train():
if args.load_optim:
if data_ginka.get("optim_state") is not None:
optimizer.load_state_dict(data_ginka["optim_state"])
print("Train from loaded state.")
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
@@ -113,9 +113,14 @@ def train():
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
target_map = batch["target_map"].to(device)
cond = batch["val_cond"].to(device)
cond = batch["cond"].to(device)
heatmap = batch["heatmap"].to(device)
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
rand = torch.randn_like(heatmap).to(device) * 0.05
if random.random() > 0.5:
heatmap = heatmap + rand
mask = np.zeros((B, H * W))
for i in range(B):
@@ -127,7 +132,7 @@ def train():
masked_input = target_map.clone()
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
logits = model(masked_input, cond)
logits = model(masked_input, cond, heatmap)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
@@ -164,7 +169,8 @@ def train():
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
# 1. 常规生成
target_map = batch["target_map"].to(device)
cond = batch["val_cond"].to(device)
cond = batch["cond"].to(device)
heatmap = batch["heatmap"].to(device)
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
@@ -178,7 +184,7 @@ def train():
masked_input = target_map.clone()
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
logits = model(masked_input, cond)
logits = model(masked_input, cond, heatmap)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)