mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-19 00:01:13 +08:00
83 lines
2.5 KiB
Python
83 lines
2.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch_geometric.nn import GCNConv, global_mean_pool
|
|
from torch_geometric.data import Data
|
|
|
|
class MinamoSimilarityVision(nn.Module):
    """Convolutional encoder mapping a multi-channel grid to a flat feature vector.

    The trunk has three stages. Each stage is a 3x3 ``Conv2d``, then ``InstanceNorm2d``,
    then ``ReLU``, and each stage doubles the channel count:
    ``in_ch -> 2*in_ch -> 4*in_ch -> 8*in_ch``. The trunk ends with
    ``AdaptiveAvgPool2d(1)``, and a single ``Linear`` projects the pooled
    ``8*in_ch`` features to ``out_ch``.

    Args:
        in_ch: Number of input channels.
        out_ch: Length of the output feature vector.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        widths = [in_ch, in_ch * 2, in_ch * 4, in_ch * 8]
        # The first two convs are padded to keep the spatial size. The last one
        # is unpadded, so it trims one cell from every border before pooling.
        paddings = [1, 1, 0]

        stages = []
        for c_in, c_out, pad in zip(widths[:-1], widths[1:], paddings):
            stages.extend([
                nn.Conv2d(c_in, c_out, 3, padding=pad),
                nn.InstanceNorm2d(c_out),
                nn.ReLU(),
            ])
        stages.append(nn.AdaptiveAvgPool2d(1))

        # Layer order and indices match the flat layout (conv.0, conv.3, conv.6,
        # fc.0), so existing checkpoints still load.
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Sequential(nn.Linear(widths[-1], out_ch))

    def forward(self, x):
        """Encode ``x`` into one feature row per leading-dimension entry.

        Assumes ``x`` is ``(batch, in_ch, H, W)``; TODO confirm against callers.
        Because the last 3x3 conv is unpadded, ``H`` and ``W`` must be at least 3.

        Returns:
            Tensor of shape ``(x.size(0), out_ch)``.
        """
        pooled = self.conv(x)                    # (batch, 8*in_ch, 1, 1)
        flat = pooled.view(pooled.size(0), -1)   # (batch, 8*in_ch)
        return self.fc(flat)
|
|
|
|
class MinamoSimilarityTopo(nn.Module):
    """Graph encoder: per-node input MLP, three GCN layers, mean pooling, projection.

    Processing steps:

    1. Node features are projected to ``hidden_dim`` (``Linear -> LayerNorm -> ReLU``).
    2. Three ``GCNConv`` layers widen the features to 2x, 4x and 8x ``hidden_dim``.
       Each layer is followed by ``LayerNorm`` and ``ReLU``.
    3. Node features are averaged per graph with ``global_mean_pool``.
    4. The pooled vector is projected to ``out_ch``.

    Args:
        in_ch: Length of each node's input feature vector.
        hidden_dim: Base hidden width; the GCN layers use 2x, 4x and 8x of it.
        out_ch: Length of the per-graph output vector.
    """

    def __init__(self, in_ch, hidden_dim, out_ch):
        super().__init__()
        width1 = hidden_dim * 2
        width2 = hidden_dim * 4
        width3 = hidden_dim * 8

        self.input_fc = nn.Sequential(
            nn.Linear(in_ch, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
        )

        # Registration order matches the original layout: convs, then norms,
        # then the output head. This keeps parameter order and state_dict keys.
        self.conv1 = GCNConv(hidden_dim, width1)
        self.conv2 = GCNConv(width1, width2)
        self.conv3 = GCNConv(width2, width3)

        self.norm1 = nn.LayerNorm(width1)
        self.norm2 = nn.LayerNorm(width2)
        self.norm3 = nn.LayerNorm(width3)

        self.output_fc = nn.Sequential(nn.Linear(width3, out_ch))

    def forward(self, graph: Data):
        """Encode a (possibly batched) PyG graph into one vector per graph.

        Reads ``graph.x`` (node features), ``graph.edge_index`` and ``graph.batch``.
        ``graph.batch`` may be ``None``, which presumably means a single unbatched
        graph; verify this against the installed PyG version's ``global_mean_pool``.

        Returns:
            Tensor with one row of length ``out_ch`` per graph.
        """
        h = self.input_fc(graph.x)
        edges = graph.edge_index

        gcn_stack = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
        )
        for conv, norm in gcn_stack:
            h = F.relu(norm(conv(h, edges)))

        pooled = global_mean_pool(h, graph.batch)
        return self.output_fc(pooled)
|
|
|
|
class MinamoSimilarityModel(nn.Module):
    """Two-branch encoder producing a vision embedding and a topology embedding.

    The vision branch (``MinamoSimilarityVision``) encodes a grid tensor with
    ``tile_type`` channels. The topology branch (``MinamoSimilarityTopo``) encodes
    a graph whose node features have length ``tile_type``. Both branches output
    vectors of the same length, ``feat_dim``.

    Args:
        tile_type: Number of tile categories. It is used as the vision input
            channel count and as the graph node-feature length (presumably a
            one-hot tile encoding; TODO confirm against the data pipeline).
        feat_dim: Output width of both branches. Defaults to 512, the
            previously hard-coded value.
        topo_hidden_dim: Base hidden width of the graph branch. Defaults to 64,
            the previously hard-coded value.
    """

    def __init__(self, tile_type=32, feat_dim=512, topo_hidden_dim=64):
        super().__init__()
        # Both branches share one output width so their embeddings stay directly
        # comparable. The defaults reproduce the original architecture exactly.
        self.vision = MinamoSimilarityVision(tile_type, feat_dim)
        self.topo = MinamoSimilarityTopo(tile_type, topo_hidden_dim, feat_dim)

    def forward(self, x, graph):
        """Run both branches.

        Args:
            x: Input to the vision branch.
            graph: PyG ``Data``/``Batch`` input to the topology branch.

        Returns:
            Tuple ``(vis_feat, topo_feat)`` of the two branch outputs.
        """
        vis_feat = self.vision(x)
        topo_feat = self.topo(graph)
        return vis_feat, topo_feat
|
|
|