mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-22 19:31:12 +08:00
feat: minamo model 修改为余弦退火+重启调度
This commit is contained in:
parent
c8d5c84ee5
commit
1ebd03390d
@ -1,7 +1,7 @@
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
class MinamoLoss(nn.Module):
|
class MinamoLoss(nn.Module):
|
||||||
def __init__(self, vision_weight=1, topo_weight=0):
|
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.vision_weight = vision_weight
|
self.vision_weight = vision_weight
|
||||||
self.topo_weight = topo_weight
|
self.topo_weight = topo_weight
|
||||||
|
|||||||
@ -14,21 +14,21 @@ class MinamoVisionModel(nn.Module):
|
|||||||
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels),
|
nn.BatchNorm2d(conv_channels),
|
||||||
CBAM(conv_channels),
|
CBAM(conv_channels),
|
||||||
nn.ReLU(),
|
nn.GELU(),
|
||||||
nn.MaxPool2d(2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.3),
|
nn.Dropout2d(0.4),
|
||||||
|
|
||||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels*2),
|
nn.BatchNorm2d(conv_channels*2),
|
||||||
CBAM(conv_channels*2),
|
CBAM(conv_channels*2),
|
||||||
nn.ReLU(),
|
nn.GELU(),
|
||||||
nn.MaxPool2d(2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.3),
|
nn.Dropout2d(0.4),
|
||||||
|
|
||||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels*4),
|
nn.BatchNorm2d(conv_channels*4),
|
||||||
CBAM(conv_channels*4),
|
CBAM(conv_channels*4),
|
||||||
nn.ReLU(),
|
nn.GELU(),
|
||||||
|
|
||||||
nn.AdaptiveMaxPool2d(1)
|
nn.AdaptiveMaxPool2d(1)
|
||||||
)
|
)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|||||||
os.makedirs("result", exist_ok=True)
|
os.makedirs("result", exist_ok=True)
|
||||||
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
||||||
|
|
||||||
epochs = 100
|
epochs = 150
|
||||||
|
|
||||||
def collate_fn(batch):
|
def collate_fn(batch):
|
||||||
"""动态处理不同尺寸地图的批处理"""
|
"""动态处理不同尺寸地图的批处理"""
|
||||||
@ -52,8 +52,8 @@ def train():
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 设定优化器与调度器
|
# 设定优化器与调度器
|
||||||
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)
|
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
|
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||||
criterion = MinamoLoss()
|
criterion = MinamoLoss()
|
||||||
|
|
||||||
# 开始训练
|
# 开始训练
|
||||||
|
|||||||
@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
|
|||||||
self.channel_att = nn.Sequential(
|
self.channel_att = nn.Sequential(
|
||||||
nn.AdaptiveAvgPool2d(1),
|
nn.AdaptiveAvgPool2d(1),
|
||||||
nn.Conv2d(channels, channels//reduction, 1),
|
nn.Conv2d(channels, channels//reduction, 1),
|
||||||
nn.ReLU(),
|
nn.GELU(),
|
||||||
nn.Conv2d(channels//reduction, channels, 1),
|
nn.Conv2d(channels//reduction, channels, 1),
|
||||||
nn.Sigmoid()
|
nn.Sigmoid()
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user