diff --git a/minamo/model/loss.py b/minamo/model/loss.py
index 3212724..fe99bca 100644
--- a/minamo/model/loss.py
+++ b/minamo/model/loss.py
@@ -1,7 +1,7 @@
 import torch.nn as nn
 
 class MinamoLoss(nn.Module):
-    def __init__(self, vision_weight=1, topo_weight=0):
+    def __init__(self, vision_weight=0.4, topo_weight=0.6):
         super().__init__()
         self.vision_weight = vision_weight
         self.topo_weight = topo_weight
diff --git a/minamo/model/vision.py b/minamo/model/vision.py
index 9813b18..4f6c8bb 100644
--- a/minamo/model/vision.py
+++ b/minamo/model/vision.py
@@ -14,21 +14,21 @@ class MinamoVisionModel(nn.Module):
             nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
             nn.BatchNorm2d(conv_channels),
             CBAM(conv_channels),
-            nn.ReLU(),
+            nn.GELU(),
             nn.MaxPool2d(2),
-            nn.Dropout2d(0.3),
+            nn.Dropout2d(0.4),
 
             nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
             nn.BatchNorm2d(conv_channels*2),
             CBAM(conv_channels*2),
-            nn.ReLU(),
+            nn.GELU(),
             nn.MaxPool2d(2),
-            nn.Dropout2d(0.3),
+            nn.Dropout2d(0.4),
 
             nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
             nn.BatchNorm2d(conv_channels*4),
             CBAM(conv_channels*4),
-            nn.ReLU(),
+            nn.GELU(),
             nn.AdaptiveMaxPool2d(1)
         )
 
diff --git a/minamo/train.py b/minamo/train.py
index 16b8263..cbe481f 100644
--- a/minamo/train.py
+++ b/minamo/train.py
@@ -12,7 +12,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 os.makedirs("result", exist_ok=True)
 os.makedirs("result/minamo_checkpoint", exist_ok=True)
 
-epochs = 100
+epochs = 150
 
 def collate_fn(batch):
     """Dynamically batch maps of different sizes."""
@@ -52,8 +52,8 @@ def train():
     )
 
     # Set up the optimizer and scheduler
-    optimizer = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=1e-2)
-    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)
+    optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
+    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2, eta_min=1e-6)
     criterion = MinamoLoss()
 
     # Start training
diff --git a/shared/attention.py b/shared/attention.py
index e485ee1..c6232d5 100644
--- a/shared/attention.py
+++ b/shared/attention.py
@@ -9,7 +9,7 @@ class ChannelAttention(nn.Module):
         self.channel_att = nn.Sequential(
             nn.AdaptiveAvgPool2d(1),
             nn.Conv2d(channels, channels//reduction, 1),
-            nn.ReLU(),
+            nn.GELU(),
             nn.Conv2d(channels//reduction, channels, 1),
             nn.Sigmoid()
         )
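
Two notes on the hunks above. First, the rebalance in `minamo/model/loss.py`: the old defaults (`vision_weight=1`, `topo_weight=0`) silently disabled the topological term, so the new `0.4`/`0.6` split does not just retune it, it turns it on and makes it the dominant term. The sketch below is a guess at the effect, assuming `forward()` (which is not part of this diff) combines the two sub-losses linearly:

```python
import torch
import torch.nn as nn

class MinamoLoss(nn.Module):
    def __init__(self, vision_weight=0.4, topo_weight=0.6):
        super().__init__()
        self.vision_weight = vision_weight
        self.topo_weight = topo_weight

    def forward(self, vision_loss: torch.Tensor, topo_loss: torch.Tensor) -> torch.Tensor:
        # Hypothetical combination; the real forward() is not shown in this diff.
        # With the old defaults this returned vision_loss unchanged and ignored
        # topo_loss entirely; the new weights make topo_loss the larger term.
        return self.vision_weight * vision_loss + self.topo_weight * topo_loss
```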
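
Second, the scheduler change in `minamo/train.py` interacts with the epoch bump: `CosineAnnealingWarmRestarts(T_0=10, T_mult=2)` produces cosine cycles of 10, 20, 40, and 80 epochs, restarting at epochs 10, 30, and 70, and since 10 + 20 + 40 + 80 = 150, raising `epochs` from 100 to 150 lets the final annealing cycle finish exactly when training ends. A standalone sketch of the resulting learning-rate trajectory (the `nn.Linear` model is a dummy, present only to give the optimizer some parameters):

```python
import torch
import torch.optim as optim

model = torch.nn.Linear(1, 1)  # dummy; the scheduler only needs an optimizer
optimizer = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

for epoch in range(150):
    lr = optimizer.param_groups[0]["lr"]
    if epoch in (0, 9, 10, 29, 30, 69, 70, 149):
        # lr is back at 1e-3 right after each restart (epochs 10, 30, 70)
        # and near eta_min just before each one (epochs 9, 29, 69, 149).
        print(f"epoch {epoch:3d}: lr = {lr:.2e}")
    optimizer.step()   # no-op here; keeps the step order PyTorch expects
    scheduler.step()
```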