From 9814edd1b44c735c65dfa17910bf0c999207cdea Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 8 Apr 2026 19:42:18 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=95=B0=E6=8D=AE=E9=9B=86=E6=8A=A5?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ginka/dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ginka/dataset.py b/ginka/dataset.py index 3f2669a..4df738f 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -72,18 +72,18 @@ class GinkaMaskGITDataset(Dataset): sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0) + heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] + for i in range(0, heatmap.shape[0]): if np.random.rand() < self.noise_prob: sigma = random.random() * self.sigma_rand - heatmap[i] = heatmap * sigma + np.random.randn() * (1 - sigma) + heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma) elif np.random.rand() < self.drop_prob: - heatmap[i] = np.zeros_like(heatmap[i]) - - heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W] + heatmap[i] = torch.zeros_like(heatmap[i]) if random.random() < 0.5: sigma = random.random() * self.sigma_rand - rand = torch.randn_like(heatmap) + rand = torch.rand_like(heatmap) heatmap = heatmap * (1 - sigma) + rand * sigma return {