# mirror of https://github.com/unanmed/ginka-generator.git

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel attention module (squeeze-and-excite style gating)."""

    def __init__(self, channels, reduction=8):
        super().__init__()
        # Global average pool to (B, C, 1, 1), then a bottleneck MLP built
        # from 1x1 convs that outputs one sigmoid gate per channel.
        self.channel_att = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ELU(),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Scale each channel of x by its learned attention weight.
        c_att = self.channel_att(x)
        return x * c_att
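
# Illustrative note (not from the original repo): for x of shape
# (B, C, H, W), channel_att(x) has shape (B, C, 1, 1), so the multiply
# broadcasts one scalar gate over each channel's full H x W plane, e.g.:
#
#     att = ChannelAttention(64)            # hypothetical sizes
#     y = att(torch.randn(2, 64, 13, 13))   # y.shape == (2, 64, 13, 13)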


class SpatialAttention(nn.Module):
    """Spatial attention module."""

    def __init__(self):
        super().__init__()
        # A 7x7 conv over the stacked max/mean channel descriptors
        # produces a single-channel sigmoid attention map.
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, 7, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Pool across the channel dimension to get two (B, 1, H, W) maps.
        max_pool = torch.max(x, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        s_att = self.spatial_att(torch.cat([max_pool, avg_pool], dim=1))
        return x * s_att
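
# Illustrative note (not from the original repo): max_pool and avg_pool are
# each (B, 1, H, W); concatenated they form the 2-channel input the 7x7 conv
# expects, and the resulting (B, 1, H, W) map broadcasts over all channels:
#
#     att = SpatialAttention()
#     y = att(torch.randn(2, 64, 13, 13))   # hypothetical sizes, shape preserved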


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel then spatial attention."""

    def __init__(self, channels, reduction=8):
        super().__init__()
        self.channel_att = ChannelAttention(channels, reduction)
        self.spatial_att = SpatialAttention()

    def forward(self, x):
        # Both submodules already return the gated feature map, so they are
        # simply applied in sequence. (Multiplying their outputs by x again,
        # as the original code did, would apply the input a second time.)
        x = self.channel_att(x)
        return self.spatial_att(x)
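
# Usage sketch (an assumption about intended use, not taken from this repo):
# CBAM is shape-preserving, so it can be appended to any conv block, e.g.:
#
#     block = nn.Sequential(
#         nn.Conv2d(32, 32, 3, padding=1),  # hypothetical layer sizes
#         nn.BatchNorm2d(32),
#         nn.ReLU(),
#         CBAM(32),
#     )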


class SEBlock(nn.Module):
    """Squeeze-and-Excitation block."""

    def __init__(self, channel, reduction=4):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.GELU(),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Squeeze: global average pool to (B, C); excite: per-channel gates
        # reshaped to (B, C, 1, 1) so they broadcast over the spatial dims.
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y
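

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; tensor sizes are arbitrary):
    # every attention module here should preserve the input shape.
    x = torch.randn(2, 16, 13, 13)
    for mod in (ChannelAttention(16), SpatialAttention(), CBAM(16), SEBlock(16)):
        assert mod(x).shape == x.shape, type(mod).__name__
    print("all attention modules preserve input shape")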