# ginka-generator/minamo/model/vision.py
#
# 71 lines
# 2.3 KiB
# Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class DualAttention(nn.Module):
    """Gate a feature map with parallel spatial and channel attention.

    The spatial branch is a 1x1 convolution down to a single channel followed
    by a sigmoid, i.e. one gate value per spatial position. The channel branch
    global-average-pools the input to 1x1, runs it through a two-layer 1x1
    convolution bottleneck and a sigmoid, i.e. one gate value per channel.
    Both gates are broadcast against the input and the two gated copies are
    summed, so the output has the same shape as the input.

    Args:
        in_channels: Number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # The bottleneck keeps 1/8 of the channels but never fewer than one:
        # ``in_channels // 8`` on its own is 0 for in_channels < 8, which
        # would leave the channel branch with an empty hidden layer.
        hidden_channels = max(in_channels // 8, 1)

        # Spatial attention
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid(),
        )
        # Channel attention
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_channels, in_channels, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` gated by the spatial branch plus ``x`` gated by the channel branch.

        Args:
            x: Feature map with ``in_channels`` channels, in a layout accepted
                by ``nn.Conv2d``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        return x * self.spatial(x) + x * self.channel(x)
class MinamoVisionModel(nn.Module):
    """Compare two tile maps and produce a score in (0, 1).

    Both maps go through the same embedding table and the same convolution
    stack; the absolute difference of the two pooled feature vectors is then
    fed to a small prediction head ending in a sigmoid.

    Args:
        tile_types: Size of the embedding table (number of distinct tile ids).
        embedding_dim: Length of the vector each tile id is embedded into.
        conv_channels: Width of the first convolution stage; the second and
            third stages use 2x and 4x this width.
    """

    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16):
        super().__init__()
        # Embedding layer: maps each tile type to a learned vector.
        self.embedding = nn.Embedding(tile_types, embedding_dim)

        # Convolution part: three stages of conv -> attention -> batch norm
        # -> ReLU with widths 1x, 2x, 4x ``conv_channels``, followed by a
        # global average pool down to 1x1.
        stage_widths = (conv_channels, conv_channels * 2, conv_channels * 4)
        layers = []
        width_in = embedding_dim
        for width_out in stage_widths:
            layers.append(nn.Conv2d(width_in, width_out, 3, padding=1))
            layers.append(DualAttention(width_out))
            layers.append(nn.BatchNorm2d(width_out))
            layers.append(nn.ReLU())
            width_in = width_out
        layers.append(nn.AdaptiveAvgPool2d(1))
        self.vision_conv = nn.Sequential(*layers)

        # Prediction head.
        # NOTE(review): there is no non-linearity between the two linear
        # layers, only dropout - confirm this is intended.
        feature_dim = conv_channels * 4
        hidden_dim = conv_channels * 2
        self.vision_head = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.Dropout(0.4),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, map1, map2):
        """Return the head's sigmoid score for the pair ``(map1, map2)``.

        Args:
            map1: Integer tensor of tile ids. It must be 3-D, because the
                embedded result is permuted as a 4-D tensor; presumably
                (batch, height, width) - confirm against the caller.
            map2: Second tile map, same layout as ``map1``.

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        # Embed both maps first, moving the embedding axis to the channel
        # position expected by Conv2d.
        embedded_a = self.embedding(map1).permute(0, 3, 1, 2)
        embedded_b = self.embedding(map2).permute(0, 3, 1, 2)

        # Shared convolution stack, applied to map1 and then map2.
        pooled_a = self.vision_conv(embedded_a)
        pooled_b = self.vision_conv(embedded_b)

        # Flatten the pooled (batch, C, 1, 1) maps into (batch, C) vectors.
        features_a = pooled_a.view(pooled_a.size(0), -1)
        features_b = pooled_b.view(pooled_b.size(0), -1)

        difference = (features_a - features_b).abs()
        return self.vision_head(difference)