mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 13:21:09 +08:00
feat: 谱归一化
This commit is contained in:
parent
0f6613ebaf
commit
be20925fc8
4
cycle.sh
4
cycle.sh
@ -1,7 +1,7 @@
|
||||
i = $1
|
||||
i=$1
|
||||
while true
|
||||
do
|
||||
sh gan.sh "$i"
|
||||
((i++))
|
||||
i=$((i+1))
|
||||
echo "第 $i 次循环完成"
|
||||
done
|
||||
|
||||
6
gan.sh
6
gan.sh
@ -1,5 +1,5 @@
|
||||
# 训练部分
|
||||
python3 -m minamo.train --epochs 30 --resume true
|
||||
python3 -m minamo.train --epochs 10 --resume true
|
||||
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
|
||||
python3 -m minamo.train --epochs 10 --resume true
|
||||
python3 -m ginka.train --epochs 10 --resume true
|
||||
@ -8,8 +8,8 @@ python3 -m ginka.validate
|
||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
|
||||
cd data
|
||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned
|
||||
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
|
||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:2
|
||||
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:2
|
||||
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
|
||||
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
|
||||
cd ..
|
||||
|
||||
@ -5,7 +5,7 @@ from .unet import GinkaUNet
|
||||
from .sample import MapDownSample
|
||||
|
||||
class GinkaModel(nn.Module):
|
||||
def __init__(self, feat_dim=256, base_ch=32, num_classes=32):
|
||||
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
||||
"""Ginka Model 模型定义部分
|
||||
"""
|
||||
super().__init__()
|
||||
@ -27,6 +27,6 @@ class GinkaModel(nn.Module):
|
||||
x = self.fc(feat)
|
||||
x = x.view(-1, self.base_ch, 32, 32)
|
||||
x = self.unet(x)
|
||||
x = self.pool(x)
|
||||
x = F.interpolate(x, (13, 13), mode='bilinear')
|
||||
return x, F.softmax(x, dim=1)
|
||||
|
||||
@ -48,7 +48,7 @@ def train():
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = GinkaLoss(minamo)
|
||||
|
||||
|
||||
@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
||||
def validate():
|
||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||
model = GinkaModel()
|
||||
state = torch.load("result/ginka_checkpoint/10.pth", map_location=device)["model_state"]
|
||||
state = torch.load("result/ginka.pth", map_location=device)["model_state"]
|
||||
model.load_state_dict(state)
|
||||
model.to(device)
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset
|
||||
from shared.graph import convert_soft_map_to_graph
|
||||
|
||||
def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8):
|
||||
def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35):
|
||||
"""
|
||||
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
||||
"""
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv
|
||||
from torch_geometric.data import Data
|
||||
|
||||
@ -18,6 +19,12 @@ class MinamoTopoModel(nn.Module):
|
||||
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
|
||||
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
|
||||
|
||||
self.conv1.lin = spectral_norm(self.conv1.lin)
|
||||
self.conv2.lin = spectral_norm(self.conv2.lin)
|
||||
self.conv_ins2.lin = spectral_norm(self.conv_ins2.lin)
|
||||
self.conv_ins1.lin = spectral_norm(self.conv_ins1.lin)
|
||||
self.conv3.lin = spectral_norm(self.conv3.lin)
|
||||
|
||||
# 正则化
|
||||
self.norm1 = nn.LayerNorm(hidden_dim*16)
|
||||
self.norm2 = nn.LayerNorm(hidden_dim*16)
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils import spectral_norm
|
||||
from shared.attention import CBAM
|
||||
|
||||
class MinamoVisionModel(nn.Module):
|
||||
@ -11,26 +12,26 @@ class MinamoVisionModel(nn.Module):
|
||||
|
||||
# 卷积部分
|
||||
self.vision_conv = nn.Sequential(
|
||||
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
|
||||
spectral_norm(nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1)),
|
||||
nn.BatchNorm2d(conv_ch*2),
|
||||
CBAM(conv_ch*2),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
|
||||
spectral_norm(nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1)),
|
||||
nn.BatchNorm2d(conv_ch*4),
|
||||
CBAM(conv_ch*4),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
|
||||
spectral_norm(nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1)),
|
||||
nn.BatchNorm2d(conv_ch*8),
|
||||
CBAM(conv_ch*8),
|
||||
nn.GELU(),
|
||||
|
||||
nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1),
|
||||
spectral_norm(nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1)),
|
||||
nn.BatchNorm2d(conv_ch*8),
|
||||
CBAM(conv_ch*8),
|
||||
nn.GELU(),
|
||||
|
||||
@ -56,7 +56,7 @@ def train():
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-6)
|
||||
criterion = MinamoLoss()
|
||||
|
||||
if args.resume:
|
||||
@ -126,7 +126,7 @@ def train():
|
||||
scheduler.step()
|
||||
|
||||
# 每十轮推理一次验证集
|
||||
if (epoch + 1) % 5 == 0:
|
||||
if (epoch + 1) % 1 == 0:
|
||||
model.eval()
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user