# Mirror of https://github.com/unanmed/ginka-generator.git
# Synced 2026-05-17 15:01:10 +08:00
# 45 lines, 1.5 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class GinkaInput(nn.Module):
    """Project a flat feature vector onto a 2-D map with a single linear layer.

    The linear layer emits ``size[0] * size[1] * out_ch`` values per sample,
    which are then reshaped along dim 1 into ``(out_ch, size[0], size[1])``.

    Args:
        feat_dim: Length of the incoming feature vector.
        out_ch: Number of channels in the produced map.
        size: ``(height, width)`` of the produced map.
    """

    def __init__(self, feat_dim: int = 1024, out_ch: int = 1, size=(32, 32)):
        super().__init__()
        height, width = size[0], size[1]
        # Keep the layers inside one ``nn.Sequential`` named ``fc`` so that
        # parameter names stay ``fc.0.weight`` / ``fc.0.bias``.
        self.fc = nn.Sequential(
            nn.Linear(feat_dim, height * width * out_ch),
            # Unflatten acts on dim 1, so the input is expected to be
            # (batch, feat_dim).
            nn.Unflatten(1, (out_ch, *size)),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(batch, out_ch, *size)`` map computed from ``x``."""
        return self.fc(x)
|
|
|
|
class FeatureEncoder(nn.Module):
    """Expand a flat feature vector into a square feature map.

    Pipeline: linear layer to ``mid_ch * size * size`` values, reshape along
    dim 1 into ``(mid_ch, size, size)``, then a 1x1 convolution that maps
    ``mid_ch`` channels to ``out_ch`` channels.

    Args:
        feat_dim: Length of the incoming feature vector.
        size: Side length of the square output map.
        mid_ch: Channel count right after the reshape.
        out_ch: Channel count after the 1x1 convolution.
    """

    def __init__(self, feat_dim: int, size: int, mid_ch: int, out_ch: int):
        super().__init__()
        # Layer order (and therefore the ``encode.0`` / ``encode.2`` parameter
        # names) is kept as Linear -> Unflatten -> Conv2d.
        self.encode = nn.Sequential(
            nn.Linear(feat_dim, mid_ch * size * size),
            # Unflatten acts on dim 1, so the input is expected to be
            # (batch, feat_dim).
            nn.Unflatten(1, (mid_ch, size, size)),
            # Kernel size 1: mixes channels only, spatial size is untouched.
            nn.Conv2d(mid_ch, out_ch, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the ``(batch, out_ch, size, size)`` map computed from ``x``."""
        return self.encode(x)
|
|
|
|
class GinkaFeatureInput(nn.Module):
    """Encode one feature vector into five feature maps of decreasing size.

    Five independent ``FeatureEncoder`` instances are built. Each successive
    one halves the spatial side (32, 16, 8, 4, 2) and doubles both the
    intermediate and the output channel counts.

    Args:
        feat_dim: Length of the incoming feature vector.
        mid_ch: Intermediate channel count of the first (32x32) encoder.
        out_ch: Output channel count of the first (32x32) encoder.
    """

    def __init__(self, feat_dim: int = 1024, mid_ch: int = 1, out_ch: int = 64):
        super().__init__()
        # Attribute names encode1..encode5 are kept so that parameter names
        # (and any checkpoints built on them) remain unchanged.
        self.encode1 = FeatureEncoder(feat_dim, 32, mid_ch, out_ch)
        self.encode2 = FeatureEncoder(feat_dim, 16, mid_ch * 2, out_ch * 2)
        self.encode3 = FeatureEncoder(feat_dim, 8, mid_ch * 4, out_ch * 4)
        self.encode4 = FeatureEncoder(feat_dim, 4, mid_ch * 8, out_ch * 8)
        self.encode5 = FeatureEncoder(feat_dim, 2, mid_ch * 16, out_ch * 16)

    def forward(self, x: torch.Tensor):
        """Run every encoder on the same input ``x``.

        Returns:
            A 5-tuple ``(x1, x2, x3, x4, x5)`` holding the outputs of
            ``encode1`` through ``encode5`` in that order.
        """
        x1 = self.encode1(x)
        x2 = self.encode2(x)
        x3 = self.encode3(x)
        x4 = self.encode4(x)
        x5 = self.encode5(x)
        return x1, x2, x3, x4, x5
|