fix: 训练报错

This commit is contained in:
unanmed 2025-03-30 17:01:59 +08:00
parent 5669f49af0
commit 4f7dbb6fb3
4 changed files with 9 additions and 7 deletions

2
gan.sh
View File

@ -2,7 +2,7 @@
python3 -m minamo.train --epochs 10 --resume true
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
python3 -m minamo.train --epochs 10 --resume true
python3 -m ginka.train --epochs 70 --resume true
python3 -m ginka.train --epochs 30 --resume true
python3 -m ginka.validate
# 训练完毕,处理数据
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"

View File

@ -75,7 +75,7 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
optimizer.zero_grad()
noise = torch.randn((BATCH_SIZE, 1, 32, 32))
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
_, output_softmax = model(noise, feat_vec)
# 计算损失
@ -108,7 +108,8 @@ def train():
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
output, output_softmax = model(feat_vec)
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
output, output_softmax = model(noise, feat_vec)
print(torch.argmax(output, dim=1)[0])
# 计算损失

View File

@ -108,7 +108,8 @@ def validate():
target_topo_feat = batch["target_topo_feat"].to(device)
feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1)
# 前向传播
output, output_softmax = model(feat_vec)
noise = torch.randn((target.shape[0], 1, 32, 32)).to(device)
output, output_softmax = model(noise, feat_vec)
map_matrix = torch.argmax(output, dim=1)
for matrix in map_matrix[:].cpu():

View File

@ -29,9 +29,9 @@ class MinamoDataset(Dataset):
map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W]
min_main = random.uniform(0.7, 1)
max_main = random.uniform(0.9, 1)
epsilon = random.uniform(0, 0.3)
min_main = random.uniform(0.6, 1)
max_main = random.uniform(0.8, 1)
epsilon = random.uniform(0, 0.4)
map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon)
map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)