From df23c891c6d71f3b0bd190d088de6a72bf516ae9 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Thu, 9 Apr 2026 12:47:18 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20MaskGIT=20=E8=AE=AD=E7=BB=83=E8=B0=83?= =?UTF-8?q?=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/dataset.py | 9 ++++----- ginka/maskGIT/model.py | 3 ++- ginka/train_maskGIT.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ginka/dataset.py b/ginka/dataset.py index 4df738f..7d2d97c 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -18,7 +18,7 @@ def load_data(path: str): class GinkaMaskGITDataset(Dataset): def __init__( self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6, - noise_prob=0.2, drop_prob=0.2 + noise_prob=0.2, drop_prob=0.2, noise_sigma=0.1 ): self.data = load_data(data_path) self.sigma_rand = sigma_rand @@ -26,6 +26,7 @@ class GinkaMaskGITDataset(Dataset): self.blur_max = blur_max self.noise_prob = noise_prob self.drop_prob = drop_prob + self.noise_sigma = noise_sigma def __len__(self): return len(self.data) @@ -52,8 +53,6 @@ class GinkaMaskGITDataset(Dataset): target_np = np.flipud(target_np) for i in range(0, heatmap.shape[0]): heatmap[i] = np.flipud(heatmap[i]) - - target = torch.LongTensor(target_np.copy()) # [H, W] cond = torch.FloatTensor(item['val']) # [cond_dim] @@ -76,8 +75,8 @@ class GinkaMaskGITDataset(Dataset): for i in range(0, heatmap.shape[0]): if np.random.rand() < self.noise_prob: - sigma = random.random() * self.sigma_rand - heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma) + sigma = random.random() * self.noise_sigma + heatmap[i] = heatmap[i] * sigma + torch.rand_like(heatmap[i]) * (1 - sigma) elif np.random.rand() < self.drop_prob: heatmap[i] = torch.zeros_like(heatmap[i]) diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 12794c9..9092e1f 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -44,8 +44,9 @@ class GinkaMaskGIT(nn.Module): gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2) gate = self.cond_gate(gate_input) # [B, d_model] + heatmap = heatmap * torch.sigmoid(gate).unsqueeze(2).unsqueeze(2) heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1) - x = self.tile_embedding(map) + heatmap * torch.sigmoid(gate) + x = self.tile_embedding(map) + heatmap x = x + self.pos_embedding x = self.transformer(x) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 0bdee0a..f6938ee 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -46,7 +46,7 @@ HEATMAP_CHANNEL = 9 LABEL_SMOOTHING = 0 BLUR_MIN_SIZE = 3 BLUR_MAX_SIZE = 9 -RAND_RATIO = 0.15 +RAND_RATIO = 0.3 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 NUM_LAYERS = 4 D_MODEL = 128