From c8d5c84ee5c1306066d3312bc4108574d697c5c2 Mon Sep 17 00:00:00 2001
From: unanmed <1319491857@qq.com>
Date: Mon, 17 Mar 2025 20:56:42 +0800
Subject: [PATCH] feat: switch minamo vision to CBAM attention
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 minamo/model/vision.py | 48 +++++++++---------------------------------
 1 file changed, 10 insertions(+), 38 deletions(-)

diff --git a/minamo/model/vision.py b/minamo/model/vision.py
index 38ddbe0..9813b18 100644
--- a/minamo/model/vision.py
+++ b/minamo/model/vision.py
@@ -1,37 +1,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-
-class DualAttention(nn.Module):
-    def __init__(self, in_channels, reduction=8):
-        super().__init__()
-        self.spatial = nn.Sequential(
-            nn.Conv2d(in_channels, 1, 3, padding=1),
-            nn.Sigmoid()
-        )
-
-        self.channel = nn.Sequential(
-            nn.AdaptiveAvgPool2d(1),
-            nn.Conv2d(in_channels, in_channels // reduction, 1),
-            nn.ReLU(),
-            nn.Conv2d(in_channels // reduction, in_channels, 1),
-            nn.Sigmoid()
-        )
-
-        self.channel_max = nn.Sequential(
-            nn.AdaptiveMaxPool2d(1),
-            nn.Conv2d(in_channels, in_channels // reduction, 1),
-            nn.ReLU(),
-            nn.Conv2d(in_channels // reduction, in_channels, 1),
-            nn.Sigmoid()
-        )
-
-    def forward(self, x):
-        attn = self.spatial(x) + self.channel(x) + self.channel_max(x)
-        return x * attn
+from shared.attention import CBAM
 
 class MinamoVisionModel(nn.Module):
-    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=64, out_dim=128):
+    def __init__(self, tile_types=32, embedding_dim=32, conv_channels=64, out_dim=128):
         super().__init__()
         # Embedding layer handles the different tile types
         self.embedding = nn.Embedding(tile_types, embedding_dim)
@@ -40,21 +13,21 @@ class MinamoVisionModel(nn.Module):
         self.vision_conv = nn.Sequential(
             nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
             nn.BatchNorm2d(conv_channels),
-            DualAttention(conv_channels, reduction=12),
+            CBAM(conv_channels),
             nn.ReLU(),
-            nn.MaxPool2d(2, 2),
-            nn.Dropout2d(0.4),
+            nn.MaxPool2d(2),
+            nn.Dropout2d(0.3),
 
             nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
             nn.BatchNorm2d(conv_channels*2),
-            DualAttention(conv_channels*2, reduction=12),
+            CBAM(conv_channels*2),
             nn.ReLU(),
-            nn.MaxPool2d(2, 2),
-            nn.Dropout2d(0.4),
+            nn.MaxPool2d(2),
+            nn.Dropout2d(0.3),
 
             nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
             nn.BatchNorm2d(conv_channels*4),
-            DualAttention(conv_channels*4, reduction=12),
+            CBAM(conv_channels*4),
             nn.ReLU(),
 
             nn.AdaptiveMaxPool2d(1)
@@ -62,13 +35,12 @@ class MinamoVisionModel(nn.Module):
 
         # Output as a feature vector
         self.vision_head = nn.Sequential(
-            nn.Dropout(0.5),
+            nn.Dropout(0.4),
             nn.Linear(conv_channels*4, out_dim)
         )
 
     def forward(self, map):
         x = self.embedding(map)
-        # print(map.shape, x.shape)
         x = x.permute(0, 3, 1, 2)
         x = self.vision_conv(x)
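
Reviewer note (not part of the patch): the diff imports CBAM from shared.attention but does not
show its implementation. The sketch below illustrates the standard CBAM block (Woo et al., 2018)
that the import is assumed to provide: channel attention built from average- and max-pooled
descriptors passed through a shared MLP, followed by a 7x7 spatial attention map, each applied
multiplicatively. The class names, reduction ratio, and kernel size are illustrative assumptions;
the actual shared/attention.py may differ.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        # Shared MLP (1x1 convs) applied to both the average- and max-pooled descriptors
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, 1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(F.adaptive_avg_pool2d(x, 1))
        mx = self.mlp(F.adaptive_max_pool2d(x, 1))
        return torch.sigmoid(avg + mx)  # (B, C, 1, 1) channel weights

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        # 2-channel input: per-pixel mean and max taken over the channel dimension
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = x.mean(dim=1, keepdim=True)
        mx, _ = x.max(dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))  # (B, 1, H, W)

class CBAM(nn.Module):
    # Channel attention first, then spatial attention, both applied as multiplicative gates.
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel = ChannelAttention(channels, reduction)
        self.spatial = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel(x)
        return x * self.spatial(x)

Compared with the removed DualAttention, which summed three independent sigmoid maps (one spatial,
one avg-pooled channel, one max-pooled channel) into a single additive mask, CBAM shares the channel
MLP between the two pooled descriptors and applies the channel and spatial gates sequentially.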