mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-16 06:11:11 +08:00
57 lines
1.8 KiB
Python
57 lines
1.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
from ..common.common import GCNBlock, DoubleConvBlock
|
|
from ..common.cond import ConditionInjector
|
|
|
|
class RandomInputHead(nn.Module):
    """Two-branch (CNN + GCN) encoder head with condition injection.

    The input is encoded by two parallel branches whose outputs are
    concatenated along the channel axis. The result is fused with a 3x3
    conv block, modulated by ``cond`` through a ``ConditionInjector``, and
    finally pooled to a 13x13 grid and projected to 32 channels.

    NOTE(review): the two branch outputs must together provide 256 channels,
    because ``fusion`` starts with ``Conv2d(256, ...)``. Presumably each
    branch yields 128 — confirm against ``DoubleConvBlock`` / ``GCNBlock``.
    """

    def __init__(self):
        super().__init__()
        # Construction order and Sequential layouts are kept fixed so parameter
        # initialisation, parameters() order and state_dict keys do not change.
        # Assumed: DoubleConvBlock takes a channel progression 32 -> 64 -> 128
        # (TODO confirm in common.common).
        self.conv = DoubleConvBlock([32, 64, 128])
        # Graph branch; argument meaning is defined in common.common (not visible here).
        self.gcn = GCNBlock(32, 128, 128, 32, 32)

        merged_ch = 256  # channel count after concatenating both branches
        self.fusion = nn.Sequential(
            nn.Conv2d(merged_ch, merged_ch, kernel_size=3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(merged_ch),
            nn.ELU(),
        )

        self.out_conv = nn.Sequential(
            nn.Conv2d(merged_ch, 128, kernel_size=3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(128),
            nn.ELU(),
            # Force a fixed 13x13 spatial output regardless of the fused map size.
            nn.AdaptiveMaxPool2d((13, 13)),
            # 1x1 projection down to 32 output channels.
            nn.Conv2d(128, 32, kernel_size=1),
        )

        self.inject = ConditionInjector(merged_ch, merged_ch)

    def forward(self, x, cond):
        """Encode ``x`` under condition ``cond``.

        Args:
            x: Input tensor fed to both branches (assumed 32 channels, matching
                the first entries of the branch configs — TODO confirm).
            cond: Conditioning input forwarded unchanged to ``ConditionInjector``.

        Returns:
            Tensor with 32 channels on a 13x13 spatial grid.
        """
        # The CNN branch runs before the GCN branch, as in the original ordering.
        merged = torch.cat((self.conv(x), self.gcn(x)), dim=1)
        conditioned = self.inject(self.fusion(merged), cond)
        return self.out_conv(conditioned)
|
|
|
|
class GinkaInput(nn.Module):
    """Resample a feature map to a new spatial size with a learned linear map.

    Each channel's ``in_size[0] x in_size[1]`` plane is flattened and passed
    through a shared ``Linear -> LayerNorm -> ELU`` stack. That stack produces
    ``out_size[0] * out_size[1]`` values, which are reshaped to ``out_size``.
    The result is then refined by a ``3x3 Conv -> InstanceNorm -> ELU`` block
    mapping ``in_ch`` to ``out_ch`` channels.

    Args:
        in_ch: Channel count ``C`` of the input tensor.
        out_ch: Channel count of the output tensor.
        in_size: Expected input spatial size ``(H, W)``; ``H * W`` must equal
            ``in_size[0] * in_size[1]`` (the ``Linear`` input width).
        out_size: Output spatial size ``(H', W')``.
    """

    def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
        super().__init__()
        self.out_size = out_size
        in_features = in_size[0] * in_size[1]
        out_features = out_size[0] * out_size[1]
        # Submodule names/order and Sequential layouts are unchanged so existing
        # checkpoints (state_dict keys) remain loadable.
        self.fc = nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.ELU(),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(out_ch),
            nn.ELU(),
        )

    def forward(self, x):
        """Map ``(B, in_ch, H, W)`` to ``(B, out_ch, out_size[0], out_size[1])``.

        Raises:
            ValueError: if ``x`` is not 4-D (from the shape unpacking).
            RuntimeError: if ``H * W`` does not match the ``Linear`` input width.
        """
        B, C, H, W = x.shape
        # reshape instead of view: view fails on non-contiguous inputs (e.g.
        # after a transpose/permute); reshape only copies when it must.
        x = x.reshape(B, C, H * W)
        x = self.fc(x)  # Linear acts on the last (flattened spatial) dim per channel
        x = x.reshape(B, C, self.out_size[0], self.out_size[1])
        x = self.conv(x)
        return x
|