mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-15 21:57:52 +08:00
refactor: 拓扑改为图卷积结构
This commit is contained in:
parent
0910bddba2
commit
b43a8693ef
@ -1,6 +1,7 @@
|
||||
import json
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from torch_geometric.data import Data
|
||||
|
||||
def load_data(path: str):
|
||||
with open(path, 'r', encoding="utf-8") as f:
|
||||
@ -12,6 +13,33 @@ def load_data(path: str):
|
||||
|
||||
return data_list
|
||||
|
||||
def convert_map_to_graph(map):
    """Convert a 2D tile map into a torch_geometric graph.

    Every non-wall cell (tile value != 1) becomes a node whose feature is
    its tile type.  4-neighbour adjacency is encoded as edges; only the
    right and down neighbours are emitted, so each undirected edge is
    stored exactly once.

    Args:
        map: 2D nested sequence of integer tile ids; 1 marks a wall.
             (Parameter name shadows the builtin, kept for keyword-call
             backward compatibility.)

    Returns:
        Data with x = LongTensor[num_nodes] of tile types and
        edge_index = LongTensor[2, num_edges].
    """
    rows = len(map)
    # Robustness: an empty map used to raise IndexError on map[0].
    cols = len(map[0]) if rows > 0 else 0
    node_indices = {}
    valid_nodes = []
    node_counter = 0

    for r in range(rows):
        for c in range(cols):
            if map[r][c] != 1:  # exclude walls
                node_indices[(r, c)] = node_counter
                valid_nodes.append((r, c, map[r][c]))  # (row, col, tile type)
                node_counter += 1

    edge_list = []
    for (r, c, _) in valid_nodes:
        node = node_indices[(r, c)]
        # right neighbour
        if c + 1 < cols and (r, c + 1) in node_indices:
            edge_list.append((node, node_indices[(r, c + 1)]))
        # down neighbour
        if r + 1 < rows and (r + 1, c) in node_indices:
            edge_list.append((node, node_indices[(r + 1, c)]))

    # Bug fix: with no edges, torch.tensor([]).T has shape (0,) while
    # torch_geometric requires edge_index of shape (2, num_edges).
    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).T
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
    node_features = torch.tensor([node_type for (_, _, node_type) in valid_nodes], dtype=torch.long)

    return Data(x=node_features, edge_index=edge_index)
|
||||
|
||||
class MinamoDataset(Dataset):
|
||||
def __init__(self, data_path: str):
|
||||
self.data = load_data(data_path) # 自定义数据加载函数
|
||||
@ -25,5 +53,7 @@ class MinamoDataset(Dataset):
|
||||
torch.LongTensor(item['map1']),
|
||||
torch.LongTensor(item['map2']),
|
||||
torch.FloatTensor([item['visionSimilarity']]),
|
||||
torch.FloatTensor([item['topoSimilarity']])
|
||||
torch.FloatTensor([item['topoSimilarity']]),
|
||||
convert_map_to_graph(item['map1']),
|
||||
convert_map_to_graph(item['map2'])
|
||||
)
|
||||
|
||||
@ -8,6 +8,7 @@ class MinamoLoss(nn.Module):
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def forward(self, vis_pred, topo_pred, vis_true, topo_true):
|
||||
# print(vis_pred.shape, topo_pred.shape, vis_true.shape, topo_true.shape)
|
||||
# print(vis_pred[0].item(), topo_pred[0].item(), vis_true[0].item(), topo_true[0].item())
|
||||
vis_loss = self.mse(vis_pred, vis_true)
|
||||
topo_loss = self.mse(topo_pred, topo_true)
|
||||
|
||||
@ -1,114 +1,20 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class DualAttention(nn.Module):
    """Apply spatial and channel attention to a feature map.

    The output is the input weighted by a per-pixel spatial mask plus the
    input weighted by a per-channel gate (squeeze-and-excitation style).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Spatial attention: 1x1 conv collapses channels into one sigmoid mask.
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid()
        )
        # Channel attention: global average pool -> bottleneck 1x1 convs -> sigmoid.
        # NOTE(review): in_channels // 8 is 0 for in_channels < 8 — callers
        # appear to always pass >= 16, confirm before reusing elsewhere.
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels // 8, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels // 8, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        spatial_branch = x * self.spatial(x)
        channel_branch = x * self.channel(x)
        return spatial_branch + channel_branch
|
||||
|
||||
class DirectionalAttention(nn.Module):
    """Fuse four learned attention maps over a feature map.

    Four convs (keys 'h', 'v', 'd1', 'd2') each produce a single-channel
    attention map from the channel-mean of the input; the maps are blended
    with softmax weights derived from their spatial means.  Note the
    kernels are square — any directional behaviour comes from what each
    conv learns, not from the kernel shape.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        self.direction_convs = nn.ModuleDict({
            dir: nn.Conv2d(1, 1, kernel_size, padding=kernel_size//2,
                           padding_mode='replicate')
            for dir in ['h', 'v', 'd1', 'd2']
        })

    def forward(self, x):
        B, C, H, W = x.shape
        # Single-channel summary of the input, shared by all four branches.
        pooled = x.mean(1, keepdim=True)
        att_maps = [self.direction_convs[d](pooled) for d in ('h', 'v', 'd1', 'd2')]

        # Stack to [B,4,1,H,W]; per-branch softmax weights over the spatial
        # means have shape [B,4,1] (dim 2 survives mean([3,4])).
        combined = torch.stack(att_maps, dim=1)
        att_weights = F.softmax(combined.mean([3, 4]), dim=1)  # [B,4,1]
        # Weight each branch, collapse the branch axis to [B,1,H,W],
        # then modulate the input.
        fused = (combined * att_weights.unsqueeze(-1).unsqueeze(-1)).sum(1)
        return x * fused
|
||||
from .vision import MinamoVisionModel
|
||||
from .topo import MinamoTopoModel
|
||||
|
||||
class MinamoModel(nn.Module):
|
||||
def __init__(self, tile_types=32, embedding_dim=16, conv_channels=32):
|
||||
def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16):
|
||||
super().__init__()
|
||||
# 嵌入层处理不同图块类型
|
||||
self.embedding = nn.Embedding(tile_types, embedding_dim)
|
||||
|
||||
self.vision_conv = nn.Sequential(
|
||||
nn.Conv2d(embedding_dim, conv_channels, 3, padding=1),
|
||||
DualAttention(conv_channels),
|
||||
nn.BatchNorm2d(conv_channels),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels, conv_channels*2, 3, padding=1),
|
||||
DualAttention(conv_channels*2),
|
||||
nn.BatchNorm2d(conv_channels*2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels*2, conv_channels*4, 3, padding=1),
|
||||
DualAttention(conv_channels*4),
|
||||
nn.BatchNorm2d(conv_channels*4),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels*4, conv_channels*8, 3, padding=1),
|
||||
DualAttention(conv_channels*8),
|
||||
nn.BatchNorm2d(conv_channels*8),
|
||||
nn.ReLU(),
|
||||
nn.AdaptiveAvgPool2d(1)
|
||||
)
|
||||
|
||||
# 拓扑特征分支
|
||||
self.topo_conv = nn.Sequential(
|
||||
nn.Conv2d(embedding_dim, conv_channels, 5, padding=2), # 更大卷积核捕捉结构
|
||||
nn.BatchNorm2d(conv_channels),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels, conv_channels*2, 5, padding=2), # 更大卷积核捕捉结构
|
||||
nn.BatchNorm2d(conv_channels*2),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(conv_channels*2, conv_channels*4, 5, padding=2), # 更大卷积核捕捉结构
|
||||
nn.BatchNorm2d(conv_channels*4),
|
||||
nn.ReLU(),
|
||||
# nn.MaxPool2d(2),
|
||||
# GraphConvLayer(128, 256), # 图卷积层
|
||||
nn.AdaptiveMaxPool2d(1)
|
||||
)
|
||||
|
||||
# 多任务预测头
|
||||
self.vision_head = nn.Sequential(
|
||||
nn.Linear(conv_channels*8, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
self.topo_head = nn.Sequential(
|
||||
nn.Linear(conv_channels*4, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
# 视觉相似度部分
|
||||
self.vision_model = MinamoVisionModel(tile_types, embedding_dim, conv_channels)
|
||||
# 拓扑相似度部分
|
||||
self.topo_model = MinamoTopoModel(tile_types)
|
||||
|
||||
def forward(self, map1, map2):
|
||||
e1 = self.embedding(map1).permute(0, 3, 1, 2)
|
||||
e2 = self.embedding(map2).permute(0, 3, 1, 2)
|
||||
def forward(self, map1, map2, graph1, graph2):
|
||||
vision_sim = self.vision_model(map1, map2)
|
||||
|
||||
v1 = self.vision_conv(e1).squeeze()
|
||||
v2 = self.vision_conv(e2).squeeze()
|
||||
topo_feat1 = self.topo_model(graph1)
|
||||
topo_feat2 = self.topo_model(graph2)
|
||||
|
||||
t1 = self.topo_conv(e1).squeeze()
|
||||
t2 = self.topo_conv(e2).squeeze()
|
||||
|
||||
# 多任务输出
|
||||
vision_sim = self.vision_head(torch.abs(v1 - v2))
|
||||
topo_sim = self.topo_head(torch.abs(t1 - t2))
|
||||
|
||||
return vision_sim, topo_sim
|
||||
return vision_sim, F.cosine_similarity(topo_feat1, topo_feat2, -1).unsqueeze(-1)
|
||||
|
||||
29
minamo/model/topo.py
Normal file
29
minamo/model/topo.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import GCNConv, global_mean_pool
|
||||
from torch_geometric.data import Data
|
||||
|
||||
class MinamoTopoModel(nn.Module):
    """Graph encoder for map topology.

    Embeds integer tile-type node features, runs two GCN layers over the
    map's adjacency graph, mean-pools node embeddings per graph, and
    projects the pooled vector down to a compact representation.
    """

    def __init__(
        self, tile_types=32, emb_dim=16, hidden_dim=32, out_dim=16, mlp_dim=8
    ):
        super().__init__()
        # Tile-type embedding lookup.
        self.embedding = torch.nn.Embedding(tile_types, emb_dim)
        # Two-layer graph convolution stack.
        self.conv1 = GCNConv(emb_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        self.fc = torch.nn.Linear(out_dim, mlp_dim)  # dimension-reducing head

    def forward(self, graph: Data):
        """Encode one (possibly batched) torch_geometric graph.

        Returns a tensor of shape (batch_size, mlp_dim).
        """
        edge_index = graph.edge_index
        h = self.embedding(graph.x)
        h = F.relu(self.conv1(h, edge_index))
        h = self.conv2(h, edge_index)
        # One pooled vector per graph in the batch.
        h = global_mean_pool(h, graph.batch)
        return self.fc(h)
|
||||
|
||||
71
minamo/model/vision.py
Normal file
71
minamo/model/vision.py
Normal file
@ -0,0 +1,71 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class DualAttention(nn.Module):
    """Apply spatial and channel attention to a 4D feature map.

    Returns the input weighted by a per-pixel spatial mask plus the input
    weighted by a per-channel gate (squeeze-and-excitation style).
    """

    def __init__(self, in_channels):
        super().__init__()
        # Spatial attention: 1x1 conv collapses channels to one sigmoid mask.
        self.spatial = nn.Sequential(
            nn.Conv2d(in_channels, 1, 1),
            nn.Sigmoid()
        )
        # Channel attention: global average pool -> bottleneck 1x1 convs -> sigmoid.
        # NOTE(review): in_channels//8 is 0 for in_channels < 8 — callers in
        # this file pass >= 16; confirm before reusing elsewhere.
        self.channel = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, in_channels//8, 1),
            nn.ReLU(),
            nn.Conv2d(in_channels//8, in_channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Sum of the two attention branches, each modulating x directly.
        return x * self.spatial(x) + x * self.channel(x)
|
||||
|
||||
class MinamoVisionModel(nn.Module):
    """Siamese CNN that scores the visual similarity of two tile maps.

    Both maps share one embedding + conv tower; the head maps the absolute
    difference of the pooled features to a sigmoid similarity in (0, 1).
    """

    def __init__(self, tile_types=32, embedding_dim=16, conv_channels=16):
        super().__init__()
        # Embedding turns integer tile ids into dense vectors.
        self.embedding = nn.Embedding(tile_types, embedding_dim)

        def _stage(c_in, c_out):
            # One conv stage: 3x3 conv -> dual attention -> BN -> ReLU.
            return [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                DualAttention(c_out),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        # Conv tower; layer order matches the stage list exactly so
        # state_dict keys and seeded initialisation are stable.
        self.vision_conv = nn.Sequential(
            *_stage(embedding_dim, conv_channels),
            *_stage(conv_channels, conv_channels * 2),
            *_stage(conv_channels * 2, conv_channels * 4),
            nn.AdaptiveAvgPool2d(1),
        )

        # Prediction head on the feature difference.
        self.vision_head = nn.Sequential(
            nn.Linear(conv_channels * 4, conv_channels * 2),
            nn.Dropout(0.4),
            nn.Linear(conv_channels * 2, 1),
            nn.Sigmoid()
        )

    def _encode(self, tile_map):
        # (B, H, W) ids -> (B, emb, H, W) via permute -> flat (B, C) feature.
        # assumes tile_map is a LongTensor of shape (B, H, W) — TODO confirm
        emb = self.embedding(tile_map).permute(0, 3, 1, 2)
        feat = self.vision_conv(emb)
        return feat.view(feat.size(0), -1)

    def forward(self, map1, map2):
        """Return a (B, 1) similarity score for the two maps."""
        feat1 = self._encode(map1)
        feat2 = self._encode(map2)
        return self.vision_head(torch.abs(feat1 - feat2))
|
||||
@ -2,7 +2,7 @@ import os
|
||||
from datetime import datetime
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
from .model.model import MinamoModel
|
||||
from .model.loss import MinamoLoss
|
||||
@ -63,17 +63,19 @@ def train():
|
||||
|
||||
for batch in dataloader:
|
||||
# 数据迁移到设备
|
||||
map1, map2, vision_simi, topo_simi = batch
|
||||
map1, map2, vision_simi, topo_simi, graph1, graph2 = batch
|
||||
map1 = map1.to(device) # 转为 [B, C, H, W]
|
||||
map2 = map2.to(device)
|
||||
topo_simi = topo_simi.to(device)
|
||||
vision_simi = vision_simi.to(device)
|
||||
graph1 = graph1.to(device)
|
||||
graph2 = graph2.to(device)
|
||||
|
||||
# print(map1.shape, map2.shape)
|
||||
|
||||
# 前向传播
|
||||
optimizer.zero_grad()
|
||||
vision_pred, topo_pred = model(map1, map2)
|
||||
vision_pred, topo_pred = model(map1, map2, graph1, graph2)
|
||||
|
||||
# 计算损失
|
||||
loss = criterion(vision_pred, topo_pred, vision_simi, topo_simi)
|
||||
@ -103,13 +105,15 @@ def train():
|
||||
val_loss = 0
|
||||
with torch.no_grad():
|
||||
for val_batch in val_loader:
|
||||
map1_val, map2_val, vision_simi_val, topo_simi_val = val_batch
|
||||
map1_val, map2_val, vision_simi_val, topo_simi_val, graph1, graph2 = val_batch
|
||||
map1_val = map1_val.to(device)
|
||||
map2_val = map2_val.to(device)
|
||||
vision_simi_val = vision_simi_val.to(device)
|
||||
topo_simi_val = topo_simi_val.to(device)
|
||||
graph1 = graph1.to(device)
|
||||
graph2 = graph2.to(device)
|
||||
|
||||
vision_pred_val, topo_pred_val = model(map1_val, map2_val)
|
||||
vision_pred_val, topo_pred_val = model(map1_val, map2_val, graph1, graph2)
|
||||
loss_val = criterion(
|
||||
vision_pred_val, topo_pred_val,
|
||||
vision_simi_val, topo_simi_val
|
||||
|
||||
6
requirements.txt
Normal file
6
requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
torch
|
||||
torchvision
|
||||
torchaudio
|
||||
tqdm
|
||||
torch-geometric
|
||||
transformers
|
||||
Loading…
Reference in New Issue
Block a user