From 49ee5437327b93fa7b185ca115c39e95646b0ee7 Mon Sep 17 00:00:00 2001 From: unanmed <1319491857@qq.com> Date: Wed, 26 Mar 2025 22:44:11 +0800 Subject: [PATCH] =?UTF-8?q?fix:=20=E6=95=B0=E6=8D=AE=E9=9B=86=E6=8A=A5?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gan.sh | 2 +- ginka/dataset.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/gan.sh b/gan.sh index c8caf20..5065d5c 100644 --- a/gan.sh +++ b/gan.sh @@ -2,7 +2,7 @@ python3 -m minamo.train --epochs 10 --resume true python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json" python3 -m minamo.train --epochs 10 --resume true -python3 -m ginka.train --epochs 10 --resume true +python3 -m ginka.train --epochs 30 --resume true python3 -m ginka.validate # 训练完毕,处理数据 mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" diff --git a/ginka/dataset.py b/ginka/dataset.py index add2eda..5bc1c36 100644 --- a/ginka/dataset.py +++ b/ginka/dataset.py @@ -30,8 +30,9 @@ class GinkaDataset(Dataset): item = self.data[idx] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - target_smooth = random_smooth_onehot(target).to(self.device) + target_smooth = random_smooth_onehot(target) graph = differentiable_convert_to_data(target_smooth).to(self.device) + target = target.to(self.device) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) return {