# Mirror of https://github.com/unanmed/ginka-generator.git
# Synced 2026-05-18 07:31:11 +08:00
# 122 lines, 4.0 KiB, Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
|
||
class DualAttention(nn.Module):
    """Gate a feature map with the sum of a spatial and a channel attention.

    The spatial branch is a 1x1 convolution down to a single channel followed
    by a sigmoid, giving one gate per spatial position. The channel branch
    global-average-pools the input to 1x1, runs it through a two-layer
    1x1-convolution bottleneck and a sigmoid, giving one gate per channel.

    Args:
        in_channels: Number of channels of the input feature map.
        reduction: Divisor used to size the channel-branch bottleneck. The
            bottleneck width is ``in_channels // reduction`` but never less
            than 1, so small ``in_channels`` values do not produce a
            zero-channel convolution.

    Raises:
        ValueError: If ``reduction`` is smaller than 1.
    """

    def __init__(self, in_channels: int, reduction: int = 8) -> None:
        super().__init__()
        if reduction < 1:
            raise ValueError("reduction must be a positive integer")

        # Clamp to 1: a plain floor division yields 0 when
        # in_channels < reduction, which would make the bottleneck empty.
        hidden_channels = max(1, in_channels // reduction)

        # Spatial attention: one gate value per position.
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid(),
        )
        # Channel attention: one gate value per channel.
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x * spatial(x) + x * channel(x)``.

        Both gates broadcast against ``x``: the spatial gate has a single
        channel and the channel gate has 1x1 spatial extent.
        """
        return x * self.spatial(x) + x * self.channel(x)
|
||
|
||
class DirectionalAttention(nn.Module):
    """Spatial attention fused from four learned convolution branches.

    The input is averaged over its channel axis, the resulting single-channel
    map is passed through four independent ``kernel_size`` x ``kernel_size``
    convolutions (registered under the keys ``'h'``, ``'v'``, ``'d1'`` and
    ``'d2'``), and the four responses are blended with per-sample softmax
    weights derived from each response's spatial mean. The blended map
    multiplies the input; it is not squashed, so the multiplier is unbounded.

    NOTE(review): all four branches use the same square kernel; nothing in
    this block ties a branch to a particular orientation beyond its key name.

    Args:
        kernel_size: Side length of each branch's convolution kernel. The
            padding is ``kernel_size // 2``, so it should be odd for the
            output to keep the input's spatial size.
    """

    def __init__(self, kernel_size: int = 7) -> None:
        super().__init__()
        self.direction_convs = nn.ModuleDict({
            name: nn.Conv2d(
                1, 1, kernel_size,
                padding=kernel_size // 2,
                padding_mode='replicate',
            )
            for name in ['h', 'v', 'd1', 'd2']
        })

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled by the fused attention map.

        Args:
            x: A 4-D tensor; axis 1 is averaged away before the branches.

        Returns:
            A tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (B, C, H, W), got {x.dim()}-D"
            )

        # Channel-wise mean, computed once and shared by every branch.
        pooled = x.mean(1, keepdim=True)  # [B,1,H,W]

        # Per-branch responses; ModuleDict iterates in insertion order.
        branch_maps = torch.stack(
            [conv(pooled) for conv in self.direction_convs.values()],
            dim=1,
        )  # [B,4,1,H,W]

        # Dynamic fusion: softmax over the branch axis of the spatial means.
        branch_weights = F.softmax(branch_maps.mean([3, 4]), dim=1)  # [B,4,1]
        weighted = branch_maps * branch_weights.unsqueeze(-1).unsqueeze(-1)
        fused = weighted.sum(1)  # [B,1,H,W]
        return x * fused
|
||
|
||
class MinamoModel(nn.Module):
    """Score a pair of tile maps with two sigmoid-bounded heads.

    Both maps go through the same embedding and convolutional stack, are
    pooled to fixed-length vectors, and are combined into one pair vector
    that feeds a "vision" head and a "topo" head.

    Args:
        num_tile_types: Size of the tile-type vocabulary for the embedding.
        embedding_dim: Width of each tile embedding.
        conv_channels: Output channels of the first convolution stage; the
            second and third stages use 2x and 4x this value.
    """

    def __init__(self, num_tile_types, embedding_dim=64, conv_channels=256):
        super().__init__()
        # Embedding layer that maps each tile type to a vector.
        self.embedding = nn.Embedding(num_tile_types, embedding_dim)

        # Shared convolutional feature extractor. Every stage is
        # conv -> dual attention -> directional attention -> ReLU -> BN.
        widths = [
            embedding_dim,
            conv_channels,
            conv_channels * 2,
            conv_channels * 4,
        ]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, 3, padding=1),
                DualAttention(c_out),
                DirectionalAttention(),
                nn.ReLU(),
                nn.BatchNorm2d(c_out),
            ])
        self.conv_layers = nn.Sequential(*stages)

        # Adaptive pooling collapses any spatial size to 1x1.
        self.pool = nn.ModuleDict({
            'avg': nn.AdaptiveAvgPool2d((1, 1)),
            'max': nn.AdaptiveMaxPool2d((1, 1)),
        })

        # Multi-task prediction heads.
        # Input width: final conv channels x 2 poolings x 4 interaction terms.
        head_dim = widths[-1] * 2 * 4

        self.vision_head = nn.Sequential(
            nn.Linear(head_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 1),
            nn.Sigmoid(),
        )

        self.topo_head = nn.Sequential(
            nn.Linear(head_dim, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def _encode(self, tiles):
        """Turn one tile-index map into a flat feature vector.

        The embedding output is permuted with ``(0, 3, 1, 2)``, so ``tiles``
        must be a 3-D index tensor (presumably batch, height, width --
        confirm against the caller). Returns a ``[B, conv_channels * 8]``
        tensor: average- and max-pooled features concatenated.
        """
        embedded = self.embedding(tiles).permute(0, 3, 1, 2)
        features = self.conv_layers(embedded)
        pooled = [self.pool['avg'](features), self.pool['max'](features)]
        return torch.cat(pooled, dim=1).flatten(1)

    def forward(self, map1, map2):
        """Return ``(vision_sim, topo_sim)`` for the pair of maps.

        Each output has shape ``[B, 1]`` and lies in (0, 1) because both
        heads end in a sigmoid.
        """
        # Feature extraction; map1 is encoded before map2.
        feats_a = self._encode(map1)
        feats_b = self._encode(map2)

        # Feature fusion: both vectors plus their difference and product,
        # giving a [B, conv_channels * 32] pair vector.
        pair = torch.cat(
            [feats_a, feats_b, feats_a - feats_b, feats_a * feats_b],
            dim=1,
        )

        # Multi-task outputs.
        vision_sim = self.vision_head(pair)
        topo_sim = self.topo_head(pair)
        return vision_sim, topo_sim
|