From 0cb22e9cf796f091374195e4f473d9fedec7fbfd Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 11 Mar 2026 14:23:46 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=86=E5=9D=97=E6=8E=A9=E7=A0=81?= =?UTF-8?q?=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/train_maskGIT.py | 34 ++++++++++-------- ginka/transformer/mask.py | 76 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 14 deletions(-) create mode 100644 ginka/transformer/mask.py diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index e3b2473..5b8d1ea 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -16,6 +16,7 @@ from .vae_rnn.loss import VAELoss from .vae_rnn.scheduler import VAEScheduler from .dataset import GinkaRNNDataset from shared.image import matrix_to_image_cv +from .transformer.mask import MapMask # 手工标注标签定义(暂时不用): # 0. 蓝海, 1. 红海, 2: 室内, 3. 野外, 4. 左右对称, 5. 上下对称, 6. 伪对称, 7. 咸鱼层, @@ -52,6 +53,7 @@ NUM_CLASSES = 16 MASK_TOKEN = 15 GENERATE_STEP = 8 MAP_SIZE = 13 * 13 +MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 device = torch.device( "cuda:1" if torch.cuda.is_available() @@ -82,6 +84,7 @@ def train(): args = parse_arguments() model = GinkaMaskGIT(num_classes=NUM_CLASSES).to(device) + masker = MapMask([0.5, 0.5]) dataset = GinkaRNNDataset(args.train, device) dataset_val = GinkaRNNDataset(args.validate, device) @@ -98,6 +101,7 @@ def train(): name = os.path.splitext(file)[0] tile_dict[name] = cv2.imread(f"tiles2/{file}", cv2.IMREAD_UNCHANGED) + # 接续训练 if args.resume: data_ginka = torch.load(args.state_ginka, map_location=device) @@ -118,19 +122,20 @@ def train(): B, H, W = target_map.shape target_map = target_map.view(B, H * W) - # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) - r = torch.rand(B).to(device) - r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 + mask = np.zeros((B, H * W)) + for i in range(B): + mask[i] = masker.mask(H, W) - # 2. 生成掩码矩阵 - masks = torch.rand(target_map.shape).to(device) < r + mask = torch.from_numpy(mask).to(torch.bool) + + # 掩码 masked_input = target_map.clone() - masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 + masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记 logits = model(masked_input, cond) - loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=1) - loss = (loss * masks).sum() / (masks.sum() + 1e-6) + loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) + loss = (loss * mask).sum() / (mask.sum() + 1e-6) optimizer.zero_grad() loss.backward() @@ -168,19 +173,20 @@ def train(): B, H, W = target_map.shape target_map = target_map.view(B, H * W) - # 1. 随机采样掩码比例 r (遵循余弦调度效果更好) - r = torch.rand(B).to(device) - r = torch.cos(r * math.pi / 2).unsqueeze(1) # 产生更多高掩码比例的样本 + mask = np.zeros((B, H * W)) + for i in range(B): + mask[i] = masker.mask(H, W) + + mask = torch.from_numpy(mask).to(torch.bool) # 2. 生成掩码矩阵 - masks = torch.rand(target_map.shape).to(device) < r masked_input = target_map.clone() - masked_input[masks] = MASK_TOKEN # 填充为 [MASK] 标记 + masked_input[mask] = MASK_TOKEN # 填充为 [MASK] 标记 logits = model(masked_input, cond) loss = F.cross_entropy(logits.permute(0, 2, 1), target_map, reduction='none', label_smoothing=0.1) - loss = (loss * masks.view(-1)).sum() / (masks.sum() + 1e-6) + loss = (loss * mask.view(-1)).sum() / (mask.sum() + 1e-6) val_loss_total += loss.detach() diff --git a/ginka/transformer/mask.py b/ginka/transformer/mask.py new file mode 100644 index 0000000..0cfb31a --- /dev/null +++ b/ginka/transformer/mask.py @@ -0,0 +1,76 @@ +import random +import torch +import numpy as np +from scipy.ndimage import binary_dilation, binary_erosion + +class MapMask: + def __init__(self, probs: list[float] = [0.5, 0.5]): + # 掩码方案 + # 0: 纯随机掩码 + # 1: 分块随机掩码 + self.probs = [sum(probs[0:i+1]) for i in range(len(probs))] + + def _sample_mask_ratio(self, alpha=2, beta=2, min_ratio=0.05, max_ratio=1): + r = np.random.beta(alpha, beta) + r = min_ratio + (max_ratio - min_ratio) * r + return r + + def mask(self, h: int, w: int): + test = random.random() + mask = None + if test < self.probs[0]: + mask = self.mask_random(h, w) + elif test < self.probs[1]: + mask = self.block_mask(h, w) + + mask = self.random_morphology(mask) + return mask.reshape(h * w) + + def mask_random(self, h: int, w: int): + # 纯随机掩码 + ratio = self._sample_mask_ratio() + total = h * w + num = int(total * ratio) + + idx = np.random.choice(total, num, replace=False) + + mask = np.zeros(total, dtype=bool) + mask[idx] = True + + return mask.reshape(h, w) + + def block_mask(self, h: int, w: int, min_block=2, max_block=None): + # 分块随机掩码 + ratio = self._sample_mask_ratio() + if max_block is None: + max_block = min(h, w) // 2 + + target = int(h * w * ratio) + mask = np.zeros((h, w), dtype=bool) + + while mask.sum() < target: + + bw = np.random.randint(min_block, max_block + 1) + bh = np.random.randint(min_block, max_block + 1) + + x = np.random.randint(0, h - bh + 1) + y = np.random.randint(0, w - bw + 1) + + mask[x:x + bh, y:y + bw] = True + + return mask + + def random_morphology(self, mask, max_iter=2): + op = np.random.choice(["none", "dilate", "erode"]) + + if op == "none": + return mask + + it = np.random.randint(1, max_iter + 1) + + if op == "dilate": + return binary_dilation(mask, iterations=it) + + if op == "erode": + return binary_erosion(mask, iterations=it) + \ No newline at end of file