ginka-generator/minamo/model/topo.py

36 lines
1.1 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
class MinamoTopoModel(nn.Module):
    """Three-layer GAT encoder producing per-node embeddings for a tile graph.

    Node features are linearly projected (spectrally normalised) into an
    embedding space, then passed through three ``GATConv`` layers. A
    LeakyReLU follows the projection and every GAT layer, including the last.

    Args:
        tile_types: Size of the last dimension of ``graph.x``.
        emb_dim: Width of the projected node embedding fed to the first GAT layer.
        hidden_dim: Per-head output width of the two hidden GAT layers.
        out_dim: Output feature width per node (the last layer uses one head).
        heads: Number of attention heads in the two hidden GAT layers. Head
            outputs are concatenated, so the following layer receives
            ``hidden_dim * heads`` features. Defaults to 8.
        negative_slope: Negative slope used by every LeakyReLU. Defaults to 0.2.
    """

    def __init__(
        self, tile_types=32, emb_dim=128, hidden_dim=256, out_dim=512,
        heads=8, negative_slope=0.2,
    ):
        super().__init__()
        self.negative_slope = negative_slope

        # Input is expected to be softmax probabilities over tile types;
        # map them directly with a single linear projection.
        self.input_proj = nn.Sequential(
            spectral_norm(nn.Linear(tile_types, emb_dim)),
            nn.LeakyReLU(negative_slope),
        )

        # GATConv concatenates its heads by default, so each hidden layer
        # emits hidden_dim * heads features. Derive the next layer's input
        # width from `heads` so the two values cannot drift apart.
        concat_dim = hidden_dim * heads
        self.conv1 = GATConv(emb_dim, hidden_dim, heads=heads)
        self.conv2 = GATConv(concat_dim, hidden_dim, heads=heads)
        # Single head on the final layer: output width is exactly out_dim.
        self.conv3 = GATConv(concat_dim, out_dim, heads=1)

    def forward(self, graph: Data):
        """Encode every node of ``graph``.

        Args:
            graph: Graph whose ``x`` has ``tile_types`` features in its last
                dimension and whose ``edge_index`` is passed unchanged to each
                GAT layer.

        Returns:
            Node embeddings with ``out_dim`` features per node, after a final
            LeakyReLU. No graph-level pooling is applied here.
        """
        x = self.input_proj(graph.x)
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.leaky_relu(conv(x, graph.edge_index), self.negative_slope)
        return x