fix: 报错

This commit is contained in:
unanmed 2026-03-31 22:35:28 +08:00
parent 8de66d87f1
commit 36a4faff4e

View File

@ -35,17 +35,20 @@ class GinkaMaskGITDataset(Dataset):
if np.random.rand() > 0.5:
k = np.random.randint(0, 4)
target_np = np.rot90(target_np, k)
heatmap = np.rot90(heatmap, k)
for i in range(0, heatmap.shape[0]):
heatmap[i] = np.rot90(heatmap[i], k)
if np.random.rand() > 0.5:
target_np = np.fliplr(target_np)
heatmap = np.fliplr(heatmap)
for i in range(0, heatmap.shape[0]):
heatmap[i] = np.fliplr(heatmap[i])
if np.random.rand() > 0.5:
target_np = np.flipud(target_np)
heatmap = np.flipud(heatmap)
for i in range(0, heatmap.shape[0]):
heatmap[i] = np.flipud(heatmap[i])
target = torch.LongTensor(target_np) # [H, W]
target = torch.LongTensor(target_np.copy()) # [H, W]
cond = torch.FloatTensor(item['val']) # [cond_dim]
if random.random() < 0.5: