feat: 分块掩码方式

This commit is contained in:
unanmed 2026-03-11 14:23:46 +08:00
parent 8672c52ff5
commit 0cb22e9cf7
2 changed files with 96 additions and 14 deletions

View File

@ -16,6 +16,7 @@ from .vae_rnn.loss import VAELoss
from .vae_rnn.scheduler import VAEScheduler
from .dataset import GinkaRNNDataset
from shared.image import matrix_to_image_cv
from .transformer.mask import MapMask
# 手工标注标签定义(暂时不用):
# 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层,
@ -52,6 +53,7 @@ NUM_CLASSES = 16
MASK_TOKEN = 15
GENERATE_STEP = 8
MAP_SIZE = 13 * 13
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
device = torch.device(
"cuda:1" if torch.cuda.is_available()
@ -82,6 +84,7 @@ def train():
args = parse_arguments()
model = GinkaMaskGIT(num_classes=NUM_CLASSES).to(device)
masker = MapMask([0.5, 0.5])
dataset = GinkaRNNDataset(args.train, device)
dataset_val = GinkaRNNDataset(args.validate, device)
@ -98,6 +101,7 @@ def train():
name = os.path.splitext(file)[0]
tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED)
# 接续训练
if args.resume:
data_ginka = torch.load(args.state_ginka, map_location=device)
@ -118,19 +122,20 @@ def train():
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
# 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
r = torch.rand(B).to(device)
r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
mask = np.zeros((B, H * W))
for i in range(B):
mask[i] = masker.mask(H, W)
# 2. 生成掩码矩阵
masks = torch.rand(target_map.shape).to(device) < r
mask = torch.from_numpy(mask).to(torch.bool)
# 掩码
masked_input = target_map.clone()
masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
logits = model(masked_input, cond)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=1)
loss = (loss * masks).sum() / (masks.sum() + 1e-6)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * mask).sum() / (mask.sum() + 1e-6)
optimizer.zero_grad()
loss.backward()
@ -168,19 +173,20 @@ def train():
B, H, W = target_map.shape
target_map = target_map.view(B, H * W)
# 1. 随机采样掩码比例 r (遵循余弦调度效果更好)
r = torch.rand(B).to(device)
r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本
mask = np.zeros((B, H * W))
for i in range(B):
mask[i] = masker.mask(H, W)
mask = torch.from_numpy(mask).to(torch.bool)
# 2. 生成掩码矩阵
masks = torch.rand(target_map.shape).to(device) < r
masked_input = target_map.clone()
masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记
masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记
logits = model(masked_input, cond)
loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1)
loss = (loss * masks.view(-1)).sum() / (masks.sum() + 1e-6)
loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6)
val_loss_total += loss.detach()

76
ginka/transformer/mask.py Normal file
View File

@ -0,0 +1,76 @@
import random
import torch
import numpy as np
from scipy.ndimage import binary_dilation, binary_erosion
class MapMask:
    """Random boolean mask generator for an ``h`` x ``w`` grid.

    Two masking schemes are supported and chosen at random on every call
    to :meth:`mask`:

    * scheme 0 - purely random (independent cells),
    * scheme 1 - random rectangular blocks, optionally followed by a
      random morphological dilation/erosion.

    ``True`` marks a masked cell.
    """

    def __init__(self, probs: list[float] = [0.5, 0.5]):
        """Store the cumulative selection thresholds for the schemes.

        Args:
            probs: Selection weight of each scheme, in scheme order
                (0: purely random, 1: block random). The list is only
                read, never mutated.

        Raises:
            ValueError: If ``probs`` is empty.
        """
        if len(probs) == 0:
            raise ValueError("probs must contain at least one weight")
        # Running sum, so scheme i is picked when the uniform draw falls
        # below self.probs[i] (and not below any earlier threshold).
        running = 0
        cumulative = []
        for weight in probs:
            running = running + weight
            cumulative.append(running)
        self.probs = cumulative

    def _sample_mask_ratio(self, alpha=2, beta=2, min_ratio=0.05, max_ratio=1):
        """Draw a mask ratio from Beta(alpha, beta) rescaled to
        ``[min_ratio, max_ratio]``."""
        r = np.random.beta(alpha, beta)
        return min_ratio + (max_ratio - min_ratio) * r

    def mask(self, h: int, w: int):
        """Generate one flattened mask of length ``h * w``.

        Args:
            h: Grid height.
            w: Grid width.

        Returns:
            A 1-D boolean NumPy array with ``h * w`` entries.
        """
        test = random.random()
        if test < self.probs[0]:
            mask = self.mask_random(h, w)
        else:
            # The last scheme doubles as the fallback. Previously a draw at
            # or above the final threshold (weights not summing to 1, or
            # float round-off) left ``mask`` as None and crashed on reshape.
            mask = self.block_mask(h, w)
            # NOTE(review): indentation was lost in the reviewed source; the
            # morphology step is assumed to belong to the block branch only
            # - confirm against the repository.
            mask = self.random_morphology(mask)
        return mask.reshape(h * w)

    def mask_random(self, h: int, w: int):
        """Purely random mask: a sampled fraction of cells, chosen without
        replacement, is set to ``True``.

        Returns:
            An ``(h, w)`` boolean array.
        """
        ratio = self._sample_mask_ratio()
        total = h * w
        num = int(total * ratio)
        idx = np.random.choice(total, num, replace=False)
        mask = np.zeros(total, dtype=bool)
        mask[idx] = True
        return mask.reshape(h, w)

    def block_mask(self, h: int, w: int, min_block=2, max_block=None):
        """Block random mask: stamp random rectangles until at least the
        sampled fraction of cells is masked.

        Args:
            h: Grid height.
            w: Grid width.
            min_block: Smallest block side length.
            max_block: Largest block side length; defaults to
                ``min(h, w) // 2``.

        Returns:
            An ``(h, w)`` boolean array.
        """
        ratio = self._sample_mask_ratio()
        limit = min(h, w)
        if max_block is None:
            max_block = limit // 2
        # Clamp the side-length range so that it is never empty and a block
        # always fits in the grid. Without this, small grids (e.g. 3x3 with
        # min_block=2) made np.random.randint raise "low >= high".
        lo = max(1, min(min_block, limit))
        hi = min(max(max_block, lo), limit)

        target = int(h * w * ratio)
        mask = np.zeros((h, w), dtype=bool)
        while mask.sum() < target:
            # Draw order (width, height, row, column) matches the original
            # so seeded runs on ordinary grid sizes are unchanged.
            bw = np.random.randint(lo, hi + 1)
            bh = np.random.randint(lo, hi + 1)
            row = np.random.randint(0, h - bh + 1)
            col = np.random.randint(0, w - bw + 1)
            mask[row:row + bh, col:col + bw] = True
        return mask

    def random_morphology(self, mask, max_iter=2):
        """Randomly leave the mask alone, dilate it, or erode it.

        Args:
            mask: 2-D boolean array.
            max_iter: Upper bound (inclusive) on morphology iterations.

        Returns:
            The transformed mask. If erosion would remove every masked cell,
            the input mask is returned unchanged so the sample keeps at
            least the positions it already had.
        """
        op = np.random.choice(["none", "dilate", "erode"])
        if op == "none":
            return mask
        iterations = int(np.random.randint(1, max_iter + 1))
        if op == "dilate":
            return binary_dilation(mask, iterations=iterations)
        eroded = binary_erosion(mask, iterations=iterations)
        # An all-False mask marks nothing for prediction; keep the original.
        return eroded if eroded.any() else mask