From 954c0109b9f38cc98ff33f4d40f7f4c358187f51 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 11 Mar 2026 12:29:33 +0800 Subject: [PATCH] fix: mask --- ginka/train_maskGIT.py | 55 +++++++++++++++++++++-------------------- tiles2/15.png | Bin 0 -> 414 bytes 2 files changed, 28 insertions(+), 27 deletions(-) create mode 100644 tiles2/15.png diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 4abcc68..2f3049d 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -51,6 +51,7 @@ VAL_BATCH_DIVIDER = 16 NUM_CLASSES = 16 MASK_TOKEN = 15 GENERATE_STEP = 8 +MAP_SIZE = 13 * 13 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -111,30 +112,30 @@ def train(): for epoch in tqdm(range(args.epochs), desc="VAE Training", disable=disable_tqdm): loss_total = torch.Tensor([0]).to(device) - for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): - target_map = batch["target_map"].to(device) - cond = batch["val_cond"].to(device) - B, H, W = target_map.shape - target_map = target_map.view(B, H * W) + # for batch in tqdm(dataloader, leave=False, desc="Epoch Progress", disable=disable_tqdm): + # target_map = batch["target_map"].to(device) + # cond = batch["val_cond"].to(device) + # B, H, W = target_map.shape + # target_map = target_map.view(B, H * W) - # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) - r = torch.rand(B).to(device) - r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 + # # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) + # r = torch.rand(B).to(device) + # r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 - # 2. 生成掩码矩阵 - masks = torch.rand(target_map.shape).to(device) < r - masked_input = target_map.clone() - masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 + # # 2. 生成掩码矩阵 + # masks = torch.rand(target_map.shape).to(device) < r + # masked_input = target_map.clone() + # masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 - logits = model(masked_input, cond) + # logits = model(masked_input, cond) - loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none') - loss = (loss * masks).sum() / (masks.sum() + 1e-6) + # loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=1) + # loss = (loss * masks).sum() / (masks.sum() + 1e-6) - optimizer.zero_grad() - loss.backward() - optimizer.step() - loss_total += loss.detach() + # optimizer.zero_grad() + # loss.backward() + # optimizer.step() + # loss_total += loss.detach() scheduler.step() @@ -178,7 +179,7 @@ def train(): logits = model(masked_input, cond) - loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none') + loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) loss = (loss * masks.view(-1)).sum() / (masks.sum() + 1e-6) val_loss_total += loss.detach() @@ -193,28 +194,28 @@ def train(): idx += 1 # 2. 从头完整生成 - map = torch.full((1, 169), MASK_TOKEN).to(device) - for _ in range(GENERATE_STEP): + map = torch.full((1, MAP_SIZE), MASK_TOKEN).to(device) + for i in range(GENERATE_STEP): # 1. 预测 - logits = model(map, cond) + logits = model(map, cond) # [1, H * W, num_classes] probs = F.softmax(logits, dim=-1) # 2. 采样(为了多样性,这里可以使用概率采样而不是取最大值) dist = torch.distributions.Categorical(probs) - sampled_tiles = dist.sample() # (1, 169) + sampled_tiles = dist.sample() # [1, H * W] # 3. 计算置信度 (模型对采样结果的信心程度) confidences = torch.gather(probs, -1, sampled_tiles.unsqueeze(-1)).squeeze(-1) # 4. 决定本轮要固定多少个格子 (上凸函数逻辑) - ratio = math.cos((GENERATE_STEP) * math.pi / 2) - num_to_mask = int(ratio * 169) + ratio = math.cos(((i + 1) / GENERATE_STEP) * math.pi / 2) + num_to_mask = math.floor(ratio * MAP_SIZE) # 5. 更新画布:保留置信度最高的部分,其余位置设回 MASK # 注意:这里逻辑上通常是保留当前步预测中置信度最高的,并结合已有的非 mask 部分 if num_to_mask > 0: _, mask_indices = torch.topk(confidences, k=num_to_mask, largest=False) - sampled_tiles.scatter_(1, mask_indices, MASK_TOKEN) + sampled_tiles = sampled_tiles.scatter(1, mask_indices, MASK_TOKEN) map = sampled_tiles if (map == MASK_TOKEN).sum() == 0: diff --git a/tiles2/15.png b/tiles2/15.png new file mode 100644 index 0000000000000000000000000000000000000000..eb627852c81f06a4f7aa982b763d620b37d3e28d GIT binary patch literal 414 zcmV;P0b%}$P)Px$SV=@dR9HvtmcbE%Fc3vU37oVdTEGih2`6d9@uDrzkta*QS(xv97aTSrnOw+4 zM=<~Kck`FRWnM*QeYXPyJMg%$#00c(X`_+!F_&}Hev^WIi696F)CV+q9+ye|D&d9Pju?hfj)LDdJHR>-Q z@|~~Y{xt{_K;2BD+-wkJ+i@BgEE<3oDvuNcLpO6p?zc3g*#SJg)VDtWifcHZW4R4r z=)PsqeUaV)63^p6R)A(oK3+*nz(-EUFDD}MAbJz(kkJ{S^C1ZU+pV7z9iJ?Scw=Nf zkK!fkqeZc7#PdLpUb)Wi1aKqTFv64QiHNS3>r1m@Z3bZ9)g{<;Hv0DMvvFkP1gYXATM07*qo IM6N<$f_^NkHvj+t literal 0 HcmV?d00001