feat: minamo model 修改为余弦退火+重启调度 (switch the minamo model's LR scheduler to cosine annealing with warm restarts)

This commit is contained in:
unanmed 2025-03-17 22:10:06 +08:00
parent c8d5c84ee5
commit 1ebd03390d
4 changed files with 10 additions and 10 deletions

View File

@ -1,7 +1,7 @@
import torch.nn as nn
class MinamoLoss(nn.Module):
def __init__(self, vision_weight=1, topo_weight=0):
def __init__(self, vision_weight=0.4, topo_weight=0.6):
super().__init__()
self.vision_weight = vision_weight
self.topo_weight = topo_weight

View File

@ -14,21 +14,21 @@ class MinamoVisionModel(nn.Module):
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
nn.BatchNorm2d(conv_channels),
CBAM(conv_channels),
nn.ReLU(),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.3),
nn.Dropout2d(0.4),
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
nn.BatchNorm2d(conv_channels*2),
CBAM(conv_channels*2),
nn.ReLU(),
nn.GELU(),
nn.MaxPool2d(2),
nn.Dropout2d(0.3),
nn.Dropout2d(0.4),
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
nn.BatchNorm2d(conv_channels*4),
CBAM(conv_channels*4),
nn.ReLU(),
nn.GELU(),
nn.AdaptiveMaxPool2d(1)
)

View File

@ -12,7 +12,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True)
epochs = 100
epochs = 150
def collate_fn(batch):
"""动态处理不同尺寸地图的批处理"""
@ -52,8 +52,8 @@ def train():
)
# 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = MinamoLoss()
# 开始训练

View File

@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
self.channel_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels//reduction, 1),
nn.ReLU(),
nn.GELU(),
nn.Conv2d(channels//reduction, channels, 1),
nn.Sigmoid()
)