feat: minamo vision 修改为 CBAM 注意力

This commit is contained in:
unanmed 2025-03-17 20:56:42 +08:00
parent ef9d3d1504
commit c8d5c84ee5

View File

@ -1,37 +1,10 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from shared.attention import CBAM
class DualAttention(nn.Module):
def __init__(self, in_channels, reduction=8):
super().__init__()
self.spatial = nn.Sequential(
nn.Conv2d(in_channels, 1, 3, padding=1),
nn.Sigmoid()
)
self.channel = nn.Sequential(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, in_channels // reduction, 1),
nn.ReLU(),
nn.Conv2d(in_channels // reduction, in_channels, 1),
nn.Sigmoid()
)
self.channel_max = nn.Sequential(
nn.AdaptiveMaxPool2d(1),
nn.Conv2d(in_channels, in_channels // reduction, 1),
nn.ReLU(),
nn.Conv2d(in_channels // reduction, in_channels, 1),
nn.Sigmoid()
)
def forward(self, x):
attn = self.spatial(x) + self.channel(x) + self.channel_max(x)
return x * attn
class MinamoVisionModel(nn.Module): class MinamoVisionModel(nn.Module):
def __init__(self, tile_types=32, embedding_dim=16, conv_channels=64, out_dim=128): def __init__(self, tile_types=32, embedding_dim=32, conv_channels=64, out_dim=128):
super().__init__() super().__init__()
# 嵌入层处理不同图块类型 # 嵌入层处理不同图块类型
self.embedding = nn.Embedding(tile_types, embedding_dim) self.embedding = nn.Embedding(tile_types, embedding_dim)
@ -40,21 +13,21 @@ class MinamoVisionModel(nn.Module):
self.vision_conv = nn.Sequential( self.vision_conv = nn.Sequential(
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1), nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
nn.BatchNorm2d(conv_channels), nn.BatchNorm2d(conv_channels),
DualAttention(conv_channels, reduction=12), CBAM(conv_channels),
nn.ReLU(), nn.ReLU(),
nn.MaxPool2d(2, 2), nn.MaxPool2d(2),
nn.Dropout2d(0.4), nn.Dropout2d(0.3),
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1), nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
nn.BatchNorm2d(conv_channels*2), nn.BatchNorm2d(conv_channels*2),
DualAttention(conv_channels*2, reduction=12), CBAM(conv_channels*2),
nn.ReLU(), nn.ReLU(),
nn.MaxPool2d(2, 2), nn.MaxPool2d(2),
nn.Dropout2d(0.4), nn.Dropout2d(0.3),
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1), nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
nn.BatchNorm2d(conv_channels*4), nn.BatchNorm2d(conv_channels*4),
DualAttention(conv_channels*4, reduction=12), CBAM(conv_channels*4),
nn.ReLU(), nn.ReLU(),
nn.AdaptiveMaxPool2d(1) nn.AdaptiveMaxPool2d(1)
@ -62,13 +35,12 @@ class MinamoVisionModel(nn.Module):
# 输出为向量 # 输出为向量
self.vision_head = nn.Sequential( self.vision_head = nn.Sequential(
nn.Dropout(0.5), nn.Dropout(0.4),
nn.Linear(conv_channels*4, out_dim) nn.Linear(conv_channels*4, out_dim)
) )
def forward(self, map): def forward(self, map):
x = self.embedding(map) x = self.embedding(map)
# print(map.shape, x.shape)
x = x.permute(0, 3, 1, 2) x = x.permute(0, 3, 1, 2)
x = self.vision_conv(x) x = self.vision_conv(x)