diff --git a/ginka/dataset.py b/ginka/dataset.py index 4df738f..7d2d97c 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -18,7 +18,7 @@ def load_data(path: str): class GinkaMaskGITDataset(Dataset): def __init__( self, data_path: str, sigma_rand=0.1, blur_min=3, blur_max=6, - noise_prob=0.2, drop_prob=0.2 + noise_prob=0.2, drop_prob=0.2, noise_sigma=0.1 ): self.data = load_data(data_path) self.sigma_rand = sigma_rand @@ -26,6 +26,7 @@ class GinkaMaskGITDataset(Dataset): self.blur_max = blur_max self.noise_prob = noise_prob self.drop_prob = drop_prob + self.noise_sigma = noise_sigma def __len__(self): return len(self.data) @@ -52,8 +53,6 @@ class GinkaMaskGITDataset(Dataset): target_np = np.flipud(target_np) for i in range(0, heatmap.shape[0]): heatmap[i] = np.flipud(heatmap[i]) - - target = torch.LongTensor(target_np.copy()) # [H, W] cond = torch.FloatTensor(item['val']) # [cond_dim] @@ -76,8 +75,8 @@ class GinkaMaskGITDataset(Dataset): for i in range(0, heatmap.shape[0]): if np.random.rand() < self.noise_prob: - sigma = random.random() * self.sigma_rand - heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma) + sigma = random.random() * self.noise_sigma + heatmap[i] = heatmap[i] * sigma + torch.rand_like(heatmap[i]) * (1 - sigma) elif np.random.rand() < self.drop_prob: heatmap[i] = torch.zeros_like(heatmap[i]) diff --git a/ginka/maskGIT/model.py b/ginka/maskGIT/model.py index 12794c9..9092e1f 100644 --- a/ginka/maskGIT/model.py +++ b/ginka/maskGIT/model.py @@ -44,8 +44,9 @@ class GinkaMaskGIT(nn.Module): gate_input = torch.cat([heatmap_mean, heatmap_max], dim=1).squeeze(2).squeeze(2) gate = self.cond_gate(gate_input) # [B, d_model] + heatmap = heatmap * torch.sigmoid(gate).unsqueeze(2).unsqueeze(2) heatmap = heatmap.view(B, C, H * W).permute(0, 2, 1) - x = self.tile_embedding(map) + heatmap * torch.sigmoid(gate) + x = self.tile_embedding(map) + heatmap x = x + self.pos_embedding x = self.transformer(x) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index 0bdee0a..f6938ee 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -46,7 +46,7 @@ HEATMAP_CHANNEL = 9 LABEL_SMOOTHING = 0 BLUR_MIN_SIZE = 3 BLUR_MAX_SIZE = 9 -RAND_RATIO = 0.15 +RAND_RATIO = 0.3 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 NUM_LAYERS = 4 D_MODEL = 128