diff --git a/gan.sh b/gan.sh index c8caf20..5065d5c 100644 --- a/gan.sh +++ b/gan.sh @@ -2,7 +2,7 @@ python3 -m minamo.train --epochs 10 --resume true python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json" python3 -m minamo.train --epochs 10 --resume true -python3 -m ginka.train --epochs 10 --resume true +python3 -m ginka.train --epochs 30 --resume true python3 -m ginka.validate # 训练完毕,处理数据 mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" diff --git a/ginka/dataset.py b/ginka/dataset.py index add2eda..5bc1c36 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -30,8 +30,9 @@ class GinkaDataset(Dataset): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - target_smooth = random_smooth_onehot(target).to(self.device) + target_smooth = random_smooth_onehot(target) graph = differentiable_convert_to_data(target_smooth).to(self.device) + target = target.to(self.device) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) return {