mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 14:31:11 +08:00
fix: 报错
This commit is contained in:
parent
020c2cf168
commit
c64a783d5e
@ -30,7 +30,7 @@ class GinkaMaskGITDataset(Dataset):
|
||||
|
||||
target = torch.LongTensor(item['map']) # [H, W]
|
||||
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
||||
heatmap = np.array(item['heatmap'])
|
||||
heatmap = np.array(item['heatmap'], dtype=np.float32)
|
||||
|
||||
if random.random() < 0.5:
|
||||
size = random.randint(self.blur_min, self.blur_max)
|
||||
|
||||
@ -45,8 +45,8 @@ MAP_SIZE = 13 * 13
|
||||
HEATMAP_CHANNEL = 9
|
||||
LABEL_SMOOTHING = 0
|
||||
BLUR_MIN_SIZE = 3
|
||||
BLUR_MAX_SIZE = 6
|
||||
RAND_RATIO = 0.1
|
||||
BLUR_MAX_SIZE = 9
|
||||
RAND_RATIO = 0.15
|
||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||
|
||||
device = torch.device(
|
||||
@ -77,7 +77,7 @@ def train():
|
||||
|
||||
args = parse_arguments()
|
||||
|
||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device)
|
||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device)
|
||||
masker = MapMask([0.5, 0.5])
|
||||
|
||||
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user