mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 10:21:15 +08:00
fix: 报错
This commit is contained in:
parent
020c2cf168
commit
c64a783d5e
@ -30,7 +30,7 @@ class GinkaMaskGITDataset(Dataset):
|
|||||||
|
|
||||||
target = torch.LongTensor(item['map']) # [H, W]
|
target = torch.LongTensor(item['map']) # [H, W]
|
||||||
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
cond = torch.FloatTensor(item['val']) # [cond_dim]
|
||||||
heatmap = np.array(item['heatmap'])
|
heatmap = np.array(item['heatmap'], dtype=np.float32)
|
||||||
|
|
||||||
if random.random() < 0.5:
|
if random.random() < 0.5:
|
||||||
size = random.randint(self.blur_min, self.blur_max)
|
size = random.randint(self.blur_min, self.blur_max)
|
||||||
|
|||||||
@ -45,8 +45,8 @@ MAP_SIZE = 13 * 13
|
|||||||
HEATMAP_CHANNEL = 9
|
HEATMAP_CHANNEL = 9
|
||||||
LABEL_SMOOTHING = 0
|
LABEL_SMOOTHING = 0
|
||||||
BLUR_MIN_SIZE = 3
|
BLUR_MIN_SIZE = 3
|
||||||
BLUR_MAX_SIZE = 6
|
BLUR_MAX_SIZE = 9
|
||||||
RAND_RATIO = 0.1
|
RAND_RATIO = 0.15
|
||||||
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
MASK_PROBS = [0.5, 0.5] # 纯随机,分块随机
|
||||||
|
|
||||||
device = torch.device(
|
device = torch.device(
|
||||||
@ -77,7 +77,7 @@ def train():
|
|||||||
|
|
||||||
args = parse_arguments()
|
args = parse_arguments()
|
||||||
|
|
||||||
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL).to(device)
|
model = GinkaMaskGIT(num_classes=NUM_CLASSES, heatmap_channel=HEATMAP_CHANNEL, num_layers=3, d_model=128).to(device)
|
||||||
masker = MapMask([0.5, 0.5])
|
masker = MapMask([0.5, 0.5])
|
||||||
|
|
||||||
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
dataset = GinkaMaskGITDataset(args.train, sigma_rand=RAND_RATIO, blur_min=BLUR_MIN_SIZE, blur_max=BLUR_MAX_SIZE)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user