ginka-generator/ginka/model/input.py

45 lines
1.5 KiB
Python

import torch
import torch.nn as nn
class GinkaInput(nn.Module):
    """Project a flat feature vector onto a map-shaped tensor.

    One fully connected layer expands the features to
    ``size[0] * size[1] * out_ch`` values, which are then unflattened along
    dim 1 into ``(out_ch, *size)``.

    Args:
        feat_dim: Length of the input feature vector.
        out_ch: Number of channels in the output map.
        size: ``(height, width)`` of the output map.
    """

    def __init__(self, feat_dim=1024, out_ch=1, size=(32, 32)):
        super().__init__()
        height, width = size[0], size[1]
        projection = nn.Linear(feat_dim, height * width * out_ch)
        # Dim 1 of the Linear output is split into (channels, H, W).
        reshape = nn.Unflatten(1, (out_ch, *size))
        self.fc = nn.Sequential(projection, reshape)

    def forward(self, x):
        """Return ``x`` projected and reshaped to ``(batch, out_ch, *size)``.

        Expects ``x`` to be ``(batch, feat_dim)``; the unflatten on dim 1
        only yields a map for 2-D input.
        """
        return self.fc(x)
class FeatureEncoder(nn.Module):
    """Turn a feature vector into a square ``size x size`` feature map.

    The pipeline is: a Linear layer to ``mid_ch * size * size`` values, an
    unflatten of dim 1 into ``(mid_ch, size, size)``, then a 1x1 convolution
    from ``mid_ch`` to ``out_ch`` channels.

    NOTE(review): no nonlinearity sits between the Linear and the 1x1 conv,
    so together they form a single linear map — confirm this is intended.

    Args:
        feat_dim: Length of the input feature vector.
        size: Side length of the square output map.
        mid_ch: Channels produced by the Linear projection.
        out_ch: Channels produced by the 1x1 convolution.
    """

    def __init__(self, feat_dim, size, mid_ch, out_ch):
        super().__init__()
        layers = [
            nn.Linear(feat_dim, mid_ch * size * size),
            nn.Unflatten(1, (mid_ch, size, size)),
            nn.Conv2d(mid_ch, out_ch, 1),
        ]
        self.encode = nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``(batch, out_ch, size, size)`` map for ``x``."""
        return self.encode(x)
class GinkaFeatureInput(nn.Module):
    """Encode one feature vector into a pyramid of feature maps.

    Level ``i`` (0-based) is built by its own ``FeatureEncoder`` with:

    - spatial size ``base_size // 2**i``,
    - ``mid_ch * 2**i`` intermediate channels,
    - ``out_ch * 2**i`` output channels.

    With the defaults this gives maps of size 32, 16, 8, 4 and 2 with
    64, 128, 256, 512 and 1024 channels, which is the original fixed
    five-level layout.

    The encoders are registered as ``encode1`` … ``encodeN``, in that order,
    so attribute names, state_dict keys and parameter-initialization order
    match the original hand-written version.

    Args:
        feat_dim: Length of the input feature vector.
        mid_ch: Intermediate channels at the finest level (doubled per level).
        out_ch: Output channels at the finest level (doubled per level).
        base_size: Spatial size of the finest (first) level.
        num_levels: Number of pyramid levels.

    Raises:
        ValueError: If ``num_levels < 1``, or if the coarsest level would
            have a spatial size below 1.
    """

    def __init__(self, feat_dim=1024, mid_ch=1, out_ch=64,
                 base_size=32, num_levels=5):
        super().__init__()
        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels}")
        coarsest = base_size // 2 ** (num_levels - 1)
        if coarsest < 1:
            raise ValueError(
                f"base_size={base_size} is too small for {num_levels} levels "
                f"(coarsest level would be {coarsest}x{coarsest})"
            )
        self.num_levels = num_levels
        for level in range(num_levels):
            scale = 2 ** level
            # add_module keeps the original "encodeN" names (checkpoint-compatible).
            self.add_module(
                f"encode{level + 1}",
                FeatureEncoder(feat_dim, base_size // scale,
                               mid_ch * scale, out_ch * scale),
            )

    def forward(self, x):
        """Return a tuple of per-level feature maps, finest level first."""
        return tuple(
            getattr(self, f"encode{level + 1}")(x)
            for level in range(self.num_levels)
        )