feat: minamo model 修改为余弦退火+重启调度

This commit is contained in:
unanmed 2025-03-17 22:10:06 +08:00
parent c8d5c84ee5
commit 1ebd03390d
4 changed files with 10 additions and 10 deletions

View File

@@ -1,7 +1,7 @@
import torch.nn as nn import torch.nn as nn
class MinamoLoss(nn.Module): class MinamoLoss(nn.Module):
def __init__(self, vision_weight=1, topo_weight=0): def __init__(self, vision_weight=0.4, topo_weight=0.6):
super().__init__() super().__init__()
self.vision_weight = vision_weight self.vision_weight = vision_weight
self.topo_weight = topo_weight self.topo_weight = topo_weight

View File

@@ -14,21 +14,21 @@ class MinamoVisionModel(nn.Module):
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1), nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
nn.BatchNorm2d(conv_channels), nn.BatchNorm2d(conv_channels),
CBAM(conv_channels), CBAM(conv_channels),
nn.ReLU(), nn.GELU(),
nn.MaxPool2d(2), nn.MaxPool2d(2),
nn.Dropout2d(0.3), nn.Dropout2d(0.4),
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1), nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
nn.BatchNorm2d(conv_channels*2), nn.BatchNorm2d(conv_channels*2),
CBAM(conv_channels*2), CBAM(conv_channels*2),
nn.ReLU(), nn.GELU(),
nn.MaxPool2d(2), nn.MaxPool2d(2),
nn.Dropout2d(0.3), nn.Dropout2d(0.4),
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1), nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
nn.BatchNorm2d(conv_channels*4), nn.BatchNorm2d(conv_channels*4),
CBAM(conv_channels*4), CBAM(conv_channels*4),
nn.ReLU(), nn.GELU(),
nn.AdaptiveMaxPool2d(1) nn.AdaptiveMaxPool2d(1)
) )

View File

@@ -12,7 +12,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
os.makedirs("result", exist_ok=True) os.makedirs("result", exist_ok=True)
os.makedirs("result/minamo_checkpoint", exist_ok=True) os.makedirs("result/minamo_checkpoint", exist_ok=True)
epochs = 100 epochs = 150
def collate_fn(batch): def collate_fn(batch):
"""动态处理不同尺寸地图的批处理""" """动态处理不同尺寸地图的批处理"""
@@ -52,8 +52,8 @@ def train():
) )
# 设定优化器与调度器 # 设定优化器与调度器
optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2) optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6) scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
criterion = MinamoLoss() criterion = MinamoLoss()
# 开始训练 # 开始训练

View File

@@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
self.channel_att = nn.Sequential( self.channel_att = nn.Sequential(
nn.AdaptiveAvgPool2d(1), nn.AdaptiveAvgPool2d(1),
nn.Conv2d(channels, channels//reduction, 1), nn.Conv2d(channels, channels//reduction, 1),
nn.ReLU(), nn.GELU(),
nn.Conv2d(channels//reduction, channels, 1), nn.Conv2d(channels//reduction, channels, 1),
nn.Sigmoid() nn.Sigmoid()
) )