fix: 数据集报错

This commit is contained in:
unanmed 2025-03-26 22:44:11 +08:00
parent 8f892fc7f4
commit 49ee543732
2 changed files with 3 additions and 2 deletions

2
gan.sh
View File

@ -2,7 +2,7 @@
python3 -m minamo.train --epochs 10 --resume true python3 -m minamo.train --epochs 10 --resume true
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json" python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
python3 -m minamo.train --epochs 10 --resume true python3 -m minamo.train --epochs 10 --resume true
python3 -m ginka.train --epochs 10 --resume true python3 -m ginka.train --epochs 30 --resume true
python3 -m ginka.validate python3 -m ginka.validate
# 训练完毕,处理数据 # 训练完毕,处理数据
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"

View File

@ -30,8 +30,9 @@ class GinkaDataset(Dataset):
item = self.data[idx] item = self.data[idx]
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
target_smooth = random_smooth_onehot(target).to(self.device) target_smooth = random_smooth_onehot(target)
graph = differentiable_convert_to_data(target_smooth).to(self.device) graph = differentiable_convert_to_data(target_smooth).to(self.device)
target = target.to(self.device)
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph) vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
return { return {