mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
fix: 数据集报错
This commit is contained in:
parent
8f892fc7f4
commit
49ee543732
2
gan.sh
2
gan.sh
@ -2,7 +2,7 @@
|
|||||||
python3 -m minamo.train --epochs 10 --resume true
|
python3 -m minamo.train --epochs 10 --resume true
|
||||||
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
|
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
|
||||||
python3 -m minamo.train --epochs 10 --resume true
|
python3 -m minamo.train --epochs 10 --resume true
|
||||||
python3 -m ginka.train --epochs 10 --resume true
|
python3 -m ginka.train --epochs 30 --resume true
|
||||||
python3 -m ginka.validate
|
python3 -m ginka.validate
|
||||||
# 训练完毕,处理数据
|
# 训练完毕,处理数据
|
||||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||||
|
|||||||
@ -30,8 +30,9 @@ class GinkaDataset(Dataset):
|
|||||||
item = self.data[idx]
|
item = self.data[idx]
|
||||||
|
|
||||||
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
target = F.one_hot(torch.LongTensor(item['map']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||||
target_smooth = random_smooth_onehot(target).to(self.device)
|
target_smooth = random_smooth_onehot(target)
|
||||||
graph = differentiable_convert_to_data(target_smooth).to(self.device)
|
graph = differentiable_convert_to_data(target_smooth).to(self.device)
|
||||||
|
target = target.to(self.device)
|
||||||
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
vision_feat, topo_feat = self.minamo(target.unsqueeze(0), graph)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user