diff --git a/cycle.sh b/cycle.sh index c0d1ec2..accd0e7 100644 --- a/cycle.sh +++ b/cycle.sh @@ -1,7 +1,7 @@ -i = $1 +i=$1 while true do sh gan.sh "$i" - ((i++)) + i=$((i+1)) echo "第 $i 次循环完成" done diff --git a/gan.sh b/gan.sh index f701a2e..fc49a37 100644 --- a/gan.sh +++ b/gan.sh @@ -1,5 +1,5 @@ # 训练部分 -python3 -m minamo.train --epochs 30 --resume true +python3 -m minamo.train --epochs 10 --resume true python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json" python3 -m minamo.train --epochs 10 --resume true python3 -m ginka.train --epochs 10 --resume true @@ -8,8 +8,8 @@ python3 -m ginka.validate mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json" mv "minamo-eval.json" "datasets/minamo-eval-$1.json" cd data -pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned -pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10 +pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:2 +pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:2 pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json" pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json" cd .. diff --git a/ginka/model/model.py b/ginka/model/model.py index b93a4fb..06356d7 100644 --- a/ginka/model/model.py +++ b/ginka/model/model.py @@ -5,7 +5,7 @@ from .unet import GinkaUNet from .sample import MapDownSample class GinkaModel(nn.Module): - def __init__(self, feat_dim=256, base_ch=32, num_classes=32): + def __init__(self, feat_dim=256, base_ch=64, num_classes=32): """Ginka Model 模型定义部分 """ super().__init__() @@ -27,6 +27,6 @@ class GinkaModel(nn.Module): x = self.fc(feat) x = x.view(-1, self.base_ch, 32, 32) x = self.unet(x) - x = self.pool(x) + x = F.interpolate(x, (13, 13), mode='bilinear') return x, F.softmax(x, dim=1) \ No newline at end of file diff --git a/ginka/train.py b/ginka/train.py index be1fe97..61f9d1c 100644 --- a/ginka/train.py +++ b/ginka/train.py @@ -48,7 +48,7 @@ def train(): ) # 设定优化器与调度器 - optimizer = optim.AdamW(model.parameters(), lr=1e-3) + optimizer = optim.AdamW(model.parameters(), lr=5e-3) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) criterion = GinkaLoss(minamo) diff --git a/ginka/validate.py b/ginka/validate.py index 0eca862..e100603 100644 --- a/ginka/validate.py +++ b/ginka/validate.py @@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32): def validate(): print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.") model = GinkaModel() - state = torch.load("result/ginka_checkpoint/10.pth", map_location=device)["model_state"] + state = torch.load("result/ginka.pth", map_location=device)["model_state"] model.load_state_dict(state) model.to(device) diff --git a/minamo/dataset.py b/minamo/dataset.py index a543ff5..177fe8d 100644 --- a/minamo/dataset.py +++ b/minamo/dataset.py @@ -4,7 +4,7 @@ import torch.nn.functional as F from torch.utils.data import Dataset from shared.graph import convert_soft_map_to_graph -def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8): +def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35): """ 生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动 """ diff --git a/minamo/model/topo.py b/minamo/model/topo.py index f20cdb2..931dbc3 100644 --- a/minamo/model/topo.py +++ b/minamo/model/topo.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.utils import spectral_norm from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv from torch_geometric.data import Data @@ -18,6 +19,12 @@ class MinamoTopoModel(nn.Module): self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2) self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False) + self.conv1.lin = spectral_norm(self.conv1.lin) + self.conv2.lin = spectral_norm(self.conv2.lin) + self.conv_ins2.lin = spectral_norm(self.conv_ins2.lin) + self.conv_ins1.lin = spectral_norm(self.conv_ins1.lin) + self.conv3.lin = spectral_norm(self.conv3.lin) + # 正则化 self.norm1 = nn.LayerNorm(hidden_dim*16) self.norm2 = nn.LayerNorm(hidden_dim*16) diff --git a/minamo/model/vision.py b/minamo/model/vision.py index 35126d9..2bd8572 100644 --- a/minamo/model/vision.py +++ b/minamo/model/vision.py @@ -1,6 +1,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from torch.nn.utils import spectral_norm from shared.attention import CBAM class MinamoVisionModel(nn.Module): @@ -11,26 +12,26 @@ class MinamoVisionModel(nn.Module): # 卷积部分 self.vision_conv = nn.Sequential( - nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1), + spectral_norm(nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1)), nn.BatchNorm2d(conv_ch*2), CBAM(conv_ch*2), nn.GELU(), nn.MaxPool2d(2), nn.Dropout2d(0.4), - nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1), + spectral_norm(nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1)), nn.BatchNorm2d(conv_ch*4), CBAM(conv_ch*4), nn.GELU(), nn.MaxPool2d(2), nn.Dropout2d(0.4), - nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1), + spectral_norm(nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1)), nn.BatchNorm2d(conv_ch*8), CBAM(conv_ch*8), nn.GELU(), - nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1), + spectral_norm(nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1)), nn.BatchNorm2d(conv_ch*8), CBAM(conv_ch*8), nn.GELU(), diff --git a/minamo/train.py b/minamo/train.py index eaac040..b2aed3e 100644 --- a/minamo/train.py +++ b/minamo/train.py @@ -56,7 +56,7 @@ def train(): # 设定优化器与调度器 optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2) - scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6) + scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-6) criterion = MinamoLoss() if args.resume: @@ -126,7 +126,7 @@ def train(): scheduler.step() # 每十轮推理一次验证集 - if (epoch + 1) % 5 == 0: + if (epoch + 1) % 1 == 0: model.eval() val_loss = 0 with torch.no_grad():