mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-21 02:11:13 +08:00
fix: device 不正确
This commit is contained in:
parent
bf5160edac
commit
d0decfc63a
@ -417,4 +417,4 @@ class RNNGinkaLoss:
|
|||||||
target: [B, H, W]
|
target: [B, H, W]
|
||||||
"""
|
"""
|
||||||
target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2)
|
target = F.one_hot(target, num_classes=self.num_classes).float().permute(0, 3, 1, 2)
|
||||||
return F.cross_entropy(fake, target, label_smoothing=0.1, weight=self.weight)
|
return F.cross_entropy(fake, target, label_smoothing=0.1)
|
||||||
|
|||||||
@ -72,7 +72,7 @@ class GinkaMapPatch(nn.Module):
|
|||||||
mask[:, 4, 2] = 0
|
mask[:, 4, 2] = 0
|
||||||
mask[:, 4, 3] = 0
|
mask[:, 4, 3] = 0
|
||||||
mask[:, 4, 4] = 0
|
mask[:, 4, 4] = 0
|
||||||
masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5])
|
masked_result = torch.zeros([B, self.tile_classes + 1, 5, 5]).to(map.device)
|
||||||
masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
masked_result[:, 0:32] = F.one_hot(result, num_classes=32).permute(0, 3, 2, 1).float()
|
||||||
masked_result[:, 32] = mask
|
masked_result[:, 32] = mask
|
||||||
|
|
||||||
|
|||||||
@ -81,7 +81,7 @@ def train():
|
|||||||
dataset = GinkaRNNDataset(args.train, device)
|
dataset = GinkaRNNDataset(args.train, device)
|
||||||
dataset_val = GinkaRNNDataset(args.validate, device)
|
dataset_val = GinkaRNNDataset(args.validate, device)
|
||||||
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
||||||
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE)
|
dataloader_val = DataLoader(dataset_val, batch_size=BATCH_SIZE // 8)
|
||||||
|
|
||||||
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
|
optimizer_ginka = optim.AdamW(ginka_rnn.parameters(), lr=1e-4, weight_decay=1e-4)
|
||||||
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
scheduler_ginka = optim.lr_scheduler.CosineAnnealingLR(optimizer_ginka, T_max=800, eta_min=1e-6)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user