mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 22:41:14 +08:00
37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch_geometric.nn import GCNConv, TransformerConv
|
|
from torch_geometric.utils import grid
|
|
from ..common.cond import ConditionInjector
|
|
|
|
# Considering GCN as the main path of the generator; kept here for now.
|
|
|
|
class GCNBlock(nn.Module):
    """Two graph-convolution stages, each followed by LayerNorm and ELU.

    Args:
        feats: ``(in_dim, hidden_dim, out_dim)``.
            ``in_dim`` is the input size of the first ``GCNConv``.
            ``hidden_dim`` is the output size of the first ``GCNConv``, the
            input size of the second one, and the size ``norm1`` normalises.
            ``out_dim`` is the output size of the second ``GCNConv`` and the
            size ``norm2`` normalises.
    """

    def __init__(self, feats: tuple[int, int, int]):
        super().__init__()
        in_dim, hidden_dim, out_dim = feats[0], feats[1], feats[2]

        # Submodules are created in this exact order (conv1, conv2, norm1,
        # norm2) so that parameter registration and the order in which
        # random initialisation consumes the RNG stay the same.
        self.conv1 = GCNConv(in_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, out_dim)
        self.norm1 = nn.LayerNorm(hidden_dim)
        self.norm2 = nn.LayerNorm(out_dim)

    def forward(self, x, edge_index):
        """Run both conv -> LayerNorm -> ELU stages over the graph.

        Args:
            x: node features whose last dimension is ``feats[0]``
                (presumably ``(num_nodes, feats[0])``, as GCNConv expects; TODO confirm).
            edge_index: graph connectivity in a form ``GCNConv`` accepts
                (typically a ``(2, num_edges)`` index tensor).

        Returns:
            Node embeddings whose last dimension is ``feats[2]``, after ELU.
        """
        hidden = F.elu(self.norm1(self.conv1(x, edge_index)))
        return F.elu(self.norm2(self.conv2(hidden, edge_index)))
|
|
|
|
class GinkaGCNEncoder(nn.Module):
    """Placeholder for a GCN-based encoder; not implemented yet.

    The module registers no submodules or parameters and defines no
    ``forward``. Calling an instance therefore raises ``NotImplementedError``
    through ``nn.Module``'s default ``forward``.
    """

    def __init__(self):
        # Only the base-class initialisation: nothing else is set up yet.
        super().__init__()
|
|
|
|
class GinkaGCNDecoder(nn.Module):
    """Placeholder for a GCN-based decoder; not implemented yet.

    The module registers no submodules or parameters and defines no
    ``forward``. Calling an instance therefore raises ``NotImplementedError``
    through ``nn.Module``'s default ``forward``.
    """

    def __init__(self):
        # Only the base-class initialisation: nothing else is set up yet.
        super().__init__()
|
|
|
|
class GinkaGCNModel(nn.Module):
    """Placeholder for a GCN-based model; not implemented yet.

    The module registers no submodules or parameters and defines no
    ``forward``. Calling an instance therefore raises ``NotImplementedError``
    through ``nn.Module``'s default ``forward``.
    """

    def __init__(self):
        # Only the base-class initialisation: nothing else is set up yet.
        super().__init__()