fix: 数据集报错

This commit is contained in:
unanmed 2026-04-08 19:42:18 +08:00
parent dbb0b9064c
commit 9814edd1b4

View File

@ -72,18 +72,18 @@ class GinkaMaskGITDataset(Dataset):
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1 sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0) heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
for i in range(0, heatmap.shape[0]): for i in range(0, heatmap.shape[0]):
if np.random.rand() < self.noise_prob: if np.random.rand() < self.noise_prob:
sigma = random.random() * self.sigma_rand sigma = random.random() * self.sigma_rand
heatmap[i] = heatmap * sigma + np.random.randn() * (1 - sigma) heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
elif np.random.rand() < self.drop_prob: elif np.random.rand() < self.drop_prob:
heatmap[i] = np.zeros_like(heatmap[i]) heatmap[i] = torch.zeros_like(heatmap[i])
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
if random.random() < 0.5: if random.random() < 0.5:
sigma = random.random() * self.sigma_rand sigma = random.random() * self.sigma_rand
rand = torch.randn_like(heatmap) rand = torch.rand_like(heatmap)
heatmap = heatmap * (1 - sigma) + rand * sigma heatmap = heatmap * (1 - sigma) + rand * sigma
return { return {