mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 数据集报错
This commit is contained in:
parent
dbb0b9064c
commit
9814edd1b4
@ -72,18 +72,18 @@ class GinkaMaskGITDataset(Dataset):
|
||||
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
||||
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
|
||||
|
||||
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
||||
|
||||
for i in range(0, heatmap.shape[0]):
|
||||
if np.random.rand() < self.noise_prob:
|
||||
sigma = random.random() * self.sigma_rand
|
||||
heatmap[i] = heatmap * sigma + np.random.randn() * (1 - sigma)
|
||||
heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
|
||||
elif np.random.rand() < self.drop_prob:
|
||||
heatmap[i] = np.zeros_like(heatmap[i])
|
||||
|
||||
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
||||
heatmap[i] = torch.zeros_like(heatmap[i])
|
||||
|
||||
if random.random() < 0.5:
|
||||
sigma = random.random() * self.sigma_rand
|
||||
rand = torch.randn_like(heatmap)
|
||||
rand = torch.rand_like(heatmap)
|
||||
heatmap = heatmap * (1 - sigma) + rand * sigma
|
||||
|
||||
return {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user