mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
feat: 谱归一化
This commit is contained in:
parent
0f6613ebaf
commit
be20925fc8
4
cycle.sh
4
cycle.sh
@ -1,7 +1,7 @@
|
|||||||
i = $1
|
i=$1
|
||||||
while true
|
while true
|
||||||
do
|
do
|
||||||
sh gan.sh "$i"
|
sh gan.sh "$i"
|
||||||
((i++))
|
i=$((i+1))
|
||||||
echo "第 $i 次循环完成"
|
echo "第 $i 次循环完成"
|
||||||
done
|
done
|
||||||
|
|||||||
6
gan.sh
6
gan.sh
@ -1,5 +1,5 @@
|
|||||||
# 训练部分
|
# 训练部分
|
||||||
python3 -m minamo.train --epochs 30 --resume true
|
python3 -m minamo.train --epochs 10 --resume true
|
||||||
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
|
python3 -m minamo.train --epochs 10 --resume true --train "datasets/minamo-dataset-1.json" --validate "datasets/minamo-eval-1.json"
|
||||||
python3 -m minamo.train --epochs 10 --resume true
|
python3 -m minamo.train --epochs 10 --resume true
|
||||||
python3 -m ginka.train --epochs 10 --resume true
|
python3 -m ginka.train --epochs 10 --resume true
|
||||||
@ -8,8 +8,8 @@ python3 -m ginka.validate
|
|||||||
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
mv "minamo-dataset.json" "datasets/minamo-dataset-$1.json"
|
||||||
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
|
mv "minamo-eval.json" "datasets/minamo-eval-$1.json"
|
||||||
cd data
|
cd data
|
||||||
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned
|
pnpm minamo "../minamo-dataset.json" "../result/ginka_val.json" "../../Apeiria/project" assigned:100:2
|
||||||
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:10
|
pnpm minamo "../minamo-eval.json" "../result/ginka_val.json" "../../Apeiria-eval/project" assigned:100:2
|
||||||
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
|
pnpm merge "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-merged.json" "../datasets/minamo-dataset-$1.json"
|
||||||
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
|
pnpm merge "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-merged.json" "../datasets/minamo-eval-$1.json"
|
||||||
cd ..
|
cd ..
|
||||||
|
|||||||
@ -5,7 +5,7 @@ from .unet import GinkaUNet
|
|||||||
from .sample import MapDownSample
|
from .sample import MapDownSample
|
||||||
|
|
||||||
class GinkaModel(nn.Module):
|
class GinkaModel(nn.Module):
|
||||||
def __init__(self, feat_dim=256, base_ch=32, num_classes=32):
|
def __init__(self, feat_dim=256, base_ch=64, num_classes=32):
|
||||||
"""Ginka Model 模型定义部分
|
"""Ginka Model 模型定义部分
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -27,6 +27,6 @@ class GinkaModel(nn.Module):
|
|||||||
x = self.fc(feat)
|
x = self.fc(feat)
|
||||||
x = x.view(-1, self.base_ch, 32, 32)
|
x = x.view(-1, self.base_ch, 32, 32)
|
||||||
x = self.unet(x)
|
x = self.unet(x)
|
||||||
x = self.pool(x)
|
x = F.interpolate(x, (13, 13), mode='bilinear')
|
||||||
return x, F.softmax(x, dim=1)
|
return x, F.softmax(x, dim=1)
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ def train():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
|
optimizer = optim.AdamW(model.parameters(), lr=5e-3)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = GinkaLoss(minamo)
|
criterion = GinkaLoss(minamo)
|
||||||
|
|
||||||
|
|||||||
@ -66,7 +66,7 @@ def matrix_to_image_cv(map_matrix, tile_set, tile_size=32):
|
|||||||
def validate():
|
def validate():
|
||||||
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
print(f"Using {'cuda' if torch.cuda.is_available() else 'cpu'} to validate model.")
|
||||||
model = GinkaModel()
|
model = GinkaModel()
|
||||||
state = torch.load("result/ginka_checkpoint/10.pth", map_location=device)["model_state"]
|
state = torch.load("result/ginka.pth", map_location=device)["model_state"]
|
||||||
model.load_state_dict(state)
|
model.load_state_dict(state)
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import torch.nn.functional as F
|
|||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from shared.graph import convert_soft_map_to_graph
|
from shared.graph import convert_soft_map_to_graph
|
||||||
|
|
||||||
def random_smooth_onehot(onehot_map, min_main=0.8, max_main=1.0, epsilon=0.8):
|
def random_smooth_onehot(onehot_map, min_main=0.65, max_main=1.0, epsilon=0.35):
|
||||||
"""
|
"""
|
||||||
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
生成随机平滑的 one-hot 编码,使主类别概率不再固定,而是随机波动
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils import spectral_norm
|
||||||
from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv
|
from torch_geometric.nn import global_mean_pool, TopKPooling, GATConv
|
||||||
from torch_geometric.data import Data
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
@ -18,6 +19,12 @@ class MinamoTopoModel(nn.Module):
|
|||||||
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
|
self.conv_ins1 = GATConv(hidden_dim*16, hidden_dim*8, heads=2)
|
||||||
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
|
self.conv3 = GATConv(hidden_dim*16, out_dim, concat=False)
|
||||||
|
|
||||||
|
self.conv1.lin = spectral_norm(self.conv1.lin)
|
||||||
|
self.conv2.lin = spectral_norm(self.conv2.lin)
|
||||||
|
self.conv_ins2.lin = spectral_norm(self.conv_ins2.lin)
|
||||||
|
self.conv_ins1.lin = spectral_norm(self.conv_ins1.lin)
|
||||||
|
self.conv3.lin = spectral_norm(self.conv3.lin)
|
||||||
|
|
||||||
# 正则化
|
# 正则化
|
||||||
self.norm1 = nn.LayerNorm(hidden_dim*16)
|
self.norm1 = nn.LayerNorm(hidden_dim*16)
|
||||||
self.norm2 = nn.LayerNorm(hidden_dim*16)
|
self.norm2 = nn.LayerNorm(hidden_dim*16)
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.nn.utils import spectral_norm
|
||||||
from shared.attention import CBAM
|
from shared.attention import CBAM
|
||||||
|
|
||||||
class MinamoVisionModel(nn.Module):
|
class MinamoVisionModel(nn.Module):
|
||||||
@ -11,26 +12,26 @@ class MinamoVisionModel(nn.Module):
|
|||||||
|
|
||||||
# 卷积部分
|
# 卷积部分
|
||||||
self.vision_conv = nn.Sequential(
|
self.vision_conv = nn.Sequential(
|
||||||
nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1),
|
spectral_norm(nn.Conv2d(conv_ch, conv_ch*2, 3, padding=1)),
|
||||||
nn.BatchNorm2d(conv_ch*2),
|
nn.BatchNorm2d(conv_ch*2),
|
||||||
CBAM(conv_ch*2),
|
CBAM(conv_ch*2),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.MaxPool2d(2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.4),
|
nn.Dropout2d(0.4),
|
||||||
|
|
||||||
nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1),
|
spectral_norm(nn.Conv2d(conv_ch*2, conv_ch*4, 3, padding=1)),
|
||||||
nn.BatchNorm2d(conv_ch*4),
|
nn.BatchNorm2d(conv_ch*4),
|
||||||
CBAM(conv_ch*4),
|
CBAM(conv_ch*4),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
nn.MaxPool2d(2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.4),
|
nn.Dropout2d(0.4),
|
||||||
|
|
||||||
nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1),
|
spectral_norm(nn.Conv2d(conv_ch*4, conv_ch*8, 3, padding=1)),
|
||||||
nn.BatchNorm2d(conv_ch*8),
|
nn.BatchNorm2d(conv_ch*8),
|
||||||
CBAM(conv_ch*8),
|
CBAM(conv_ch*8),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
|
|
||||||
nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1),
|
spectral_norm(nn.Conv2d(conv_ch*8, conv_ch*8, 3, padding=1)),
|
||||||
nn.BatchNorm2d(conv_ch*8),
|
nn.BatchNorm2d(conv_ch*8),
|
||||||
CBAM(conv_ch*8),
|
CBAM(conv_ch*8),
|
||||||
nn.GELU(),
|
nn.GELU(),
|
||||||
|
|||||||
@ -56,7 +56,7 @@ def train():
|
|||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=3, T_mult=2, eta_min=1e-6)
|
||||||
criterion = MinamoLoss()
|
criterion = MinamoLoss()
|
||||||
|
|
||||||
if args.resume:
|
if args.resume:
|
||||||
@ -126,7 +126,7 @@ def train():
|
|||||||
scheduler.step()
|
scheduler.step()
|
||||||
|
|
||||||
# 每十轮推理一次验证集
|
# 每十轮推理一次验证集
|
||||||
if (epoch + 1) % 5 == 0:
|
if (epoch + 1) % 1 == 0:
|
||||||
model.eval()
|
model.eval()
|
||||||
val_loss = 0
|
val_loss = 0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user