feat: 谱归一化

This commit is contained in:
unanmed 2025-03-23 12:53:28 +08:00
parent 0f6613ebaf
commit be20925fc8
9 changed files with 24 additions and 16 deletions

View File

@ -1,7 +1,7 @@
i = $1
i=$1
while true
do
sh gan.sh "$i"
((i++))
i=$((i+1))
echo "$i 次循环完成"
done

6
gan.sh
View File

@ -1,5 +1,5 @@
# 训练部分
python3 -m minamo.train --epochs 30 --resume true
python3 -m minamo.train --epochs 10 --resume true
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
python3 -m minamo.train --epochs 10 --resume true
python3 -m ginka.train --epochs 10 --resume true
@ -8,8 +8,8 @@ python3 -m ginka.validate
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
cd data
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:2
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:2
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
cd ..

View File

@ -5,7 +5,7 @@ from .unet import GinkaUNet
from .sample import MapDownSample
class GinkaModel(nn.Module):
def __init__(self, feat_dim=256, base_ch=32, num_classes=32):
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
"""Ginka Model 模型定义部分
"""
super().__init__()
@ -27,6 +27,6 @@ class GinkaModel(nn.Module):
x = self.fc(feat)
x = x.view(-1, self.base_ch, 32, 32)
x = self.unet(x)
x = self.pool(x)
x = F.interpolate(x, (13, 13), mode='bilinear')
return x, F.softmax(x, dim=1)

View File

@ -48,7 +48,7 @@ def train():
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = GinkaLoss(minamo)

View File

@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
def validate():
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
model = GinkaModel()
state = torch.load("result/ginka_checkpoint/10.pth", map_location=device)["model_state"]
state = torch.load("result/ginka.pth", map_location=device)["model_state"]
model.load_state_dict(state)
model.to(device)

View File

@ -4,7 +4,7 @@ import torch.nn.functional as F
from torch.utils.data import Dataset
from shared.graph import convert_soft_map_to_graph
def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8):
def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35):
"""
生成随机平滑的 one-hot 编码使主类别概率不再固定而是随机波动
"""

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv
from torch_geometric.data import Data
@ -18,6 +19,12 @@ class MinamoTopoModel(nn.Module):
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
self.conv1.lin = spectral_norm(self.conv1.lin)
self.conv2.lin = spectral_norm(self.conv2.lin)
self.conv_ins2.lin = spectral_norm(self.conv_ins2.lin)
self.conv_ins1.lin = spectral_norm(self.conv_ins1.lin)
self.conv3.lin = spectral_norm(self.conv3.lin)
# 正则化
self.norm1 = nn.LayerNorm(hidden_dim*16)
self.norm2 = nn.LayerNorm(hidden_dim*16)

View File

@ -1,6 +1,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from shared.attention import CBAM
class MinamoVisionModel(nn.Module):
@ -11,26 +12,26 @@ class MinamoVisionModel(nn.Module):
# 卷积部分
self.vision_conv = nn.Sequential(
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
spectral_norm(nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1)),
nn.BatchNorm2d(conv_ch*2),
CBAM(conv_ch*2),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.4),
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
spectral_norm(nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1)),
nn.BatchNorm2d(conv_ch*4),
CBAM(conv_ch*4),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.4),
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
spectral_norm(nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1)),
nn.BatchNorm2d(conv_ch*8),
CBAM(conv_ch*8),
nn.GELU(),
nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1),
spectral_norm(nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1)),
nn.BatchNorm2d(conv_ch*8),
CBAM(conv_ch*8),
nn.GELU(),

View File

@ -56,7 +56,7 @@ def train():
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-6)
criterion = MinamoLoss()
if args.resume:
@ -126,7 +126,7 @@ def train():
scheduler.step()
# 每十轮推理一次验证集
if (epoch + 1) % 5 == 0:
if (epoch + 1) % 1 == 0:
model.eval()
val_loss = 0
with torch.no_grad():