diff --git a/gan.sh b/gan.sh index 4767ca5..5065d5c 100644 --- a/gan.sh +++ b/gan.sh @@ -2,7 +2,7 @@ python3 -m minamo.train --epochs 10 --resume true python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json" python3 -m minamo.train --epochs 10 --resume true -python3 -m ginka.train --epochs 70 --resume true +python3 -m ginka.train --epochs 30 --resume true python3 -m ginka.validate # 训练完毕,处理数据 mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" diff --git a/ginka/train.py b/ginka/train.py index c339629..9b0dbf7 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -75,7 +75,7 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 optimizer.zero_grad() - noise = torch.randn((BATCH_SIZE, 1, 32, 32)) + noise = torch.randn((target.shape[0], 1, 32, 32)).to(device) _, output_softmax = model(noise, feat_vec) # 计算损失 @@ -108,7 +108,8 @@ def train(): feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 - output, output_softmax = model(feat_vec) + noise = torch.randn((target.shape[0], 1, 32, 32)).to(device) + output, output_softmax = model(noise, feat_vec) print(torch.argmax(output, dim=1)[0]) # 计算损失 diff --git a/ginka/validate.py b/ginka/validate.py index 89ef941..750707f 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -108,7 +108,8 @@ def validate(): target_topo_feat = batch["target_topo_feat"].to(device) feat_vec = torch.cat([target_vision_feat, target_topo_feat], dim=-1).to(device).squeeze(1) # 前向传播 - output, output_softmax = model(feat_vec) + noise = torch.randn((target.shape[0], 1, 32, 32)).to(device) + output, output_softmax = model(noise, feat_vec) map_matrix = torch.argmax(output, dim=1) for matrix in map_matrix[:].cpu(): diff --git a/minamo/dataset.py b/minamo/dataset.py index 079c4d4..fb99639 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -29,9 +29,9 @@ class MinamoDataset(Dataset): map1_probs = F.one_hot(torch.LongTensor(item['map1']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] map2_probs = F.one_hot(torch.LongTensor(item['map2']), num_classes=32).permute(2, 0, 1).float() # [32, H, W] - min_main = random.uniform(0.7, 1) - max_main = random.uniform(0.9, 1) - epsilon = random.uniform(0, 0.3) + min_main = random.uniform(0.6, 1) + max_main = random.uniform(0.8, 1) + epsilon = random.uniform(0, 0.4) map1_probs = random_smooth_onehot(map1_probs, min_main, max_main, epsilon) map2_probs = random_smooth_onehot(map2_probs, min_main, max_main, epsilon)