mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-20 17:22:41 +08:00
feat: 训练时的 heatmap
This commit is contained in:
parent
22a2db464f
commit
c000b90794
@ -95,7 +95,7 @@ def train():
|
||||
for file in os.listdir('tiles2'):
|
||||
name = os.path.splitext(file)[0]
|
||||
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
|
||||
|
||||
|
||||
# 接续训练
|
||||
if args.resume:
|
||||
data_ginka = torch.load(args.state_ginka, map_location=device)
|
||||
@ -105,7 +105,7 @@ def train():
|
||||
if args.load_optim:
|
||||
if data_ginka.get("optim_state") is not None:
|
||||
optimizer.load_state_dict(data_ginka["optim_state"])
|
||||
|
||||
|
||||
print("Train from loaded state.")
|
||||
|
||||
for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm):
|
||||
@ -113,9 +113,14 @@ def train():
|
||||
|
||||
for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm):
|
||||
target_map = batch["target_map"].to(device)
|
||||
cond = batch["val_cond"].to(device)
|
||||
cond = batch["cond"].to(device)
|
||||
heatmap = batch["heatmap"].to(device)
|
||||
B, H, W = target_map.shape
|
||||
|
||||
target_map = target_map.view(B, H * W)
|
||||
rand = torch.randn_like(heatmap).to(device) * 0.05
|
||||
if random.random() > 0.5:
|
||||
heatmap = heatmap + rand
|
||||
|
||||
mask = np.zeros((B, H * W))
|
||||
for i in range(B):
|
||||
@ -127,7 +132,7 @@ def train():
|
||||
masked_input = target_map.clone()
|
||||
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||
|
||||
logits = model(masked_input, cond)
|
||||
logits = model(masked_input, cond, heatmap)
|
||||
|
||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
||||
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
|
||||
@ -164,7 +169,8 @@ def train():
|
||||
for batch in tqdm(dataloader_val, desc="Validating generator.", leave=False, disable=disable_tqdm):
|
||||
# 1. 常规生成
|
||||
target_map = batch["target_map"].to(device)
|
||||
cond = batch["val_cond"].to(device)
|
||||
cond = batch["cond"].to(device)
|
||||
heatmap = batch["heatmap"].to(device)
|
||||
B, H, W = target_map.shape
|
||||
target_map = target_map.view(B, H * W)
|
||||
|
||||
@ -178,7 +184,7 @@ def train():
|
||||
masked_input = target_map.clone()
|
||||
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
|
||||
|
||||
logits = model(masked_input, cond)
|
||||
logits = model(masked_input, cond, heatmap)
|
||||
|
||||
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
|
||||
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user