diff --git a/ginka/dataset.py b/ginka/dataset.py index 3f2669a..4df738f 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -72,18 +72,18 @@ class GinkaMaskGITDataset(Dataset): sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0) + heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] + for i in range(0, heatmap.shape[0]): if np.random.rand() < self.noise_prob: sigma = random.random() * self.sigma_rand - heatmap[i] = heatmap * sigma + np.random.randn() * (1 - sigma) + heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma) elif np.random.rand() < self.drop_prob: - heatmap[i] = np.zeros_like(heatmap[i]) - - heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] + heatmap[i] = torch.zeros_like(heatmap[i]) if random.random() < 0.5: sigma = random.random() * self.sigma_rand - rand = torch.randn_like(heatmap) + rand = torch.rand_like(heatmap) heatmap = heatmap * (1 - sigma) + rand * sigma return {