diff --git a/ginka/dataset.py b/ginka/dataset.py index 95b7f95..2450b6d 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -30,7 +30,7 @@ class GinkaMaskGITDataset(Dataset): target = torch.LongTensor(item['map']) # [H, W] cond = torch.FloatTensor(item['val']) # [cond_dim] - heatmap = np.array(item['heatmap']) + heatmap = np.array(item['heatmap'], dtype=np.float32) if random.random() < 0.5: size = random.randint(self.blur_min, self.blur_max) diff --git a/ginka/train_maskGIT.py b/ginka/train_maskGIT.py index f1fbc5d..024d03d 100644 --- a/ginka/train_maskGIT.py +++ b/ginka/train_maskGIT.py @@ -45,8 +45,8 @@ MAP_SIZE = 13 * 13 HEATMAP_CHANNEL = 9 LABEL_SMOOTHING = 0 BLUR_MIN_SIZE = 3 -BLUR_MAX_SIZE = 6 -RAND_RATIO = 0.1 +BLUR_MAX_SIZE = 9 +RAND_RATIO = 0.15 MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机 device = torch.device( @@ -77,7 +77,7 @@ def train(): args = parse_arguments() - model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device) + model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device) masker = MapMask([0.5, 0.5]) dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)