mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 15:01:10 +08:00
feat: minamo model 修改为余弦退火+重启调度
This commit is contained in:
parent
c8d5c84ee5
commit
1ebd03390d
@ -1,7 +1,7 @@
|
||||
import torch.nn as nn
|
||||
|
||||
class MinamoLoss(nn.Module):
|
||||
def __init__(self, vision_weight=1, topo_weight=0):
|
||||
def __init__(self, vision_weight=0.4, topo_weight=0.6):
|
||||
super().__init__()
|
||||
self.vision_weight = vision_weight
|
||||
self.topo_weight = topo_weight
|
||||
|
||||
@ -14,21 +14,21 @@ class MinamoVisionModel(nn.Module):
|
||||
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_channels),
|
||||
CBAM(conv_channels),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.3),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_channels*2),
|
||||
CBAM(conv_channels*2),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
nn.MaxPool2d(2),
|
||||
nn.Dropout2d(0.3),
|
||||
nn.Dropout2d(0.4),
|
||||
|
||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||
nn.BatchNorm2d(conv_channels*4),
|
||||
CBAM(conv_channels*4),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
|
||||
nn.AdaptiveMaxPool2d(1)
|
||||
)
|
||||
|
||||
@ -12,7 +12,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
os.makedirs("result", exist_ok=True)
|
||||
os.makedirs("result/minamo_checkpoint", exist_ok=True)
|
||||
|
||||
epochs = 100
|
||||
epochs = 150
|
||||
|
||||
def collate_fn(batch):
|
||||
"""动态处理不同尺寸地图的批处理"""
|
||||
@ -52,8 +52,8 @@ def train():
|
||||
)
|
||||
|
||||
# 设定优化器与调度器
|
||||
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
|
||||
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
|
||||
criterion = MinamoLoss()
|
||||
|
||||
# 开始训练
|
||||
|
||||
@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
|
||||
self.channel_att = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d(1),
|
||||
nn.Conv2d(channels, channels//reduction, 1),
|
||||
nn.ReLU(),
|
||||
nn.GELU(),
|
||||
nn.Conv2d(channels//reduction, channels, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user