mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 05:11:10 +08:00
fix: 训练报错
This commit is contained in:
parent
5669f49af0
commit
4f7dbb6fb3
2
gan.sh
2
gan.sh
@ -2,7 +2,7 @@
|
||||
python3 -m minamo.train --epochs 10 --resume true
|
||||
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
|
||||
python3 -m minamo.train --epochs 10 --resume true
|
||||
python3 -m ginka.train --epochs 70 --resume true
|
||||
python3 -m ginka.train --epochs 30 --resume true
|
||||
python3 -m ginka.validate
|
||||
# 训练完毕,处理数据
|
||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||
|
||||
@ -75,7 +75,7 @@ def train():
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
noise = torch.randn((BATCH_SIZE, 1, 32, 32))
|
||||
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
|
||||
_, output_softmax = model(noise, feat_vec)
|
||||
|
||||
# 计算损失
|
||||
@ -108,7 +108,8 @@ def train():
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
|
||||
# 前向传播
|
||||
output, output_softmax = model(feat_vec)
|
||||
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
|
||||
output, output_softmax = model(noise, feat_vec)
|
||||
print(torch.argmax(output, dim=1)[0])
|
||||
|
||||
# 计算损失
|
||||
|
||||
@ -108,7 +108,8 @@ def validate():
|
||||
target_topo_feat = batch["target_topo_feat"].to(device)
|
||||
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
|
||||
# 前向传播
|
||||
output, output_softmax = model(feat_vec)
|
||||
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
|
||||
output, output_softmax = model(noise, feat_vec)
|
||||
map_matrix = torch.argmax(output, dim=1)
|
||||
|
||||
for matrix in map_matrix[:].cpu():
|
||||
|
||||
@ -29,9 +29,9 @@ class MinamoDataset(Dataset):
|
||||
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
|
||||
|
||||
min_main = random.uniform(0.7, 1)
|
||||
max_main = random.uniform(0.9, 1)
|
||||
epsilon = random.uniform(0, 0.3)
|
||||
min_main = random.uniform(0.6, 1)
|
||||
max_main = random.uniform(0.8, 1)
|
||||
epsilon = random.uniform(0, 0.4)
|
||||
|
||||
map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon)
|
||||
map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user