fix: 报错

This commit is contained in:
unanmed 2026-03-30 12:34:18 +08:00
parent 020c2cf168
commit c64a783d5e
2 changed files with 4 additions and 4 deletions

View File

@ -30,7 +30,7 @@ class GinkaMaskGITDataset(Dataset):
target = torch.LongTensor(item['map']) # [H, W]
cond = torch.FloatTensor(item['val']) # [cond_dim]
heatmap = np.array(item['heatmap'])
heatmap = np.array(item['heatmap'], dtype=np.float32)
if random.random() < 0.5:
size = random.randint(self.blur_min, self.blur_max)

View File

@ -45,8 +45,8 @@ MAP_SIZE = 13 * 13
HEATMAP_CHANNEL = 9
LABEL_SMOOTHING = 0
BLUR_MIN_SIZE = 3
BLUR_MAX_SIZE = 6
RAND_RATIO = 0.1
BLUR_MAX_SIZE = 9
RAND_RATIO = 0.15
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
device = torch.device(
@ -77,7 +77,7 @@ def train():
args = parse_arguments()
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device)
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device)
masker = MapMask([0.5, 0.5])
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)