mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 数据集报错
This commit is contained in:
parent
dbb0b9064c
commit
9814edd1b4
@ -72,18 +72,18 @@ class GinkaMaskGITDataset(Dataset):
|
|||||||
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
sizeY = sizeY + 1 if random.random() < 0.5 else sizeY - 1
|
||||||
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
|
heatmap = cv2.GaussianBlur(heatmap, (sizeX, sizeY), 0)
|
||||||
|
|
||||||
|
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
||||||
|
|
||||||
for i in range(0, heatmap.shape[0]):
|
for i in range(0, heatmap.shape[0]):
|
||||||
if np.random.rand() < self.noise_prob:
|
if np.random.rand() < self.noise_prob:
|
||||||
sigma = random.random() * self.sigma_rand
|
sigma = random.random() * self.sigma_rand
|
||||||
heatmap[i] = heatmap * sigma + np.random.randn() * (1 - sigma)
|
heatmap[i] = heatmap * sigma + torch.rand_like(heatmap[i]) * (1 - sigma)
|
||||||
elif np.random.rand() < self.drop_prob:
|
elif np.random.rand() < self.drop_prob:
|
||||||
heatmap[i] = np.zeros_like(heatmap[i])
|
heatmap[i] = torch.zeros_like(heatmap[i])
|
||||||
|
|
||||||
heatmap = torch.FloatTensor(heatmap) # [heatmap_channel, H, W]
|
|
||||||
|
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
sigma = random.random() * self.sigma_rand
|
sigma = random.random() * self.sigma_rand
|
||||||
rand = torch.randn_like(heatmap)
|
rand = torch.rand_like(heatmap)
|
||||||
heatmap = heatmap * (1 - sigma) + rand * sigma
|
heatmap = heatmap * (1 - sigma) + rand * sigma
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user