mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-14 04:41:12 +08:00
feat: minamo vision 修改为 CBAM 注意力
This commit is contained in:
parent
ef9d3d1504
commit
c8d5c84ee5
@ -1,37 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from shared.attention import CBAM
|
||||||
class DualAttention(nn.Module):
    """Additive spatial + channel attention gate for 2D feature maps.

    Three sigmoid gates are computed from the input and summed:

    * ``spatial``: a 3x3 convolution down to a single channel, i.e. one
      weight per spatial location.
    * ``channel``: global average pooling followed by a 1x1-conv bottleneck,
      i.e. one weight per channel.
    * ``channel_max``: same bottleneck, but fed by global max pooling.

    The summed gate is broadcast against the input and multiplied with it.
    Because three sigmoid outputs are added, each gate value lies in the
    open interval (0, 3) rather than (0, 1).

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction: Divisor applied to ``in_channels`` to get the width of
            the bottleneck in the two channel branches.
    """

    def __init__(self, in_channels, reduction=8):
        super().__init__()

        # Integer division yields 0 when in_channels < reduction, which would
        # silently build zero-width convolutions. Keep at least one hidden
        # channel; for in_channels >= reduction the width is unchanged.
        hidden = max(1, in_channels // reduction)

        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 3, padding=1),
            nn.Sigmoid(),
        )
        self.channel = self._channel_gate(
            nn.AdaptiveAvgPool2d(1), in_channels, hidden
        )
        self.channel_max = self._channel_gate(
            nn.AdaptiveMaxPool2d(1), in_channels, hidden
        )

    @staticmethod
    def _channel_gate(pool, in_channels, hidden):
        """Build one pooled channel-gate: pool -> squeeze -> ReLU -> excite -> sigmoid."""
        # The layer order (and therefore the Sequential indices) is kept
        # stable so that parameter names in a state_dict do not change.
        return nn.Sequential(
            pool,
            nn.Conv2d(in_channels, hidden, 1),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` scaled element-wise by the summed attention gates.

        Args:
            x: Feature map with ``in_channels`` channels.
                Assumes layout (batch, channels, height, width) -- TODO confirm
                against callers.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # The spatial gate has 1 channel and the channel gates have 1x1
        # spatial size, so the sum broadcasts to the full shape of x.
        attn = self.spatial(x) + self.channel(x) + self.channel_max(x)
        return x * attn
|
|
||||||
|
|
||||||
class MinamoVisionModel(nn.Module):
|
class MinamoVisionModel(nn.Module):
|
||||||
def __init__(self, tile_types=32, embedding_dim=16, conv_channels=64, out_dim=128):
|
def __init__(self, tile_types=32, embedding_dim=32, conv_channels=64, out_dim=128):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# 嵌入层处理不同图块类型
|
# 嵌入层处理不同图块类型
|
||||||
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
||||||
@ -40,21 +13,21 @@ class MinamoVisionModel(nn.Module):
|
|||||||
self.vision_conv = nn.Sequential(
|
self.vision_conv = nn.Sequential(
|
||||||
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels),
|
nn.BatchNorm2d(conv_channels),
|
||||||
DualAttention(conv_channels, reduction=12),
|
CBAM(conv_channels),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.MaxPool2d(2, 2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.4),
|
nn.Dropout2d(0.3),
|
||||||
|
|
||||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels*2),
|
nn.BatchNorm2d(conv_channels*2),
|
||||||
DualAttention(conv_channels*2, reduction=12),
|
CBAM(conv_channels*2),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
nn.MaxPool2d(2, 2),
|
nn.MaxPool2d(2),
|
||||||
nn.Dropout2d(0.4),
|
nn.Dropout2d(0.3),
|
||||||
|
|
||||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||||
nn.BatchNorm2d(conv_channels*4),
|
nn.BatchNorm2d(conv_channels*4),
|
||||||
DualAttention(conv_channels*4, reduction=12),
|
CBAM(conv_channels*4),
|
||||||
nn.ReLU(),
|
nn.ReLU(),
|
||||||
|
|
||||||
nn.AdaptiveMaxPool2d(1)
|
nn.AdaptiveMaxPool2d(1)
|
||||||
@ -62,13 +35,12 @@ class MinamoVisionModel(nn.Module):
|
|||||||
|
|
||||||
# 输出为向量
|
# 输出为向量
|
||||||
self.vision_head = nn.Sequential(
|
self.vision_head = nn.Sequential(
|
||||||
nn.Dropout(0.5),
|
nn.Dropout(0.4),
|
||||||
nn.Linear(conv_channels*4, out_dim)
|
nn.Linear(conv_channels*4, out_dim)
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, map):
|
def forward(self, map):
|
||||||
x = self.embedding(map)
|
x = self.embedding(map)
|
||||||
# print(map.shape, x.shape)
|
|
||||||
x = x.permute(0, 3, 1, 2)
|
x = x.permute(0, 3, 1, 2)
|
||||||
|
|
||||||
x = self.vision_conv(x)
|
x = self.vision_conv(x)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user