ginka-generator/ginka/generator/input.py
2025-05-11 23:50:08 +08:00

59 lines
2.0 KiB
Python

import torch
import torch.nn as nn
from ..common.common import ConvFusionModule
from ..common.cond import ConditionInjector
from .unet import GinkaEncoderPath, GinkaDecoderPath
class RandomInputHead(nn.Module):
    """Encoder/decoder head followed by a small convolutional projection.

    The input and the condition go through a ``GinkaEncoderPath`` and a
    ``GinkaDecoderPath``; the decoded map is then pooled to 15x15 and passed
    through an unpadded 3x3 convolution, which yields a 32-channel 13x13 map.
    """

    def __init__(self):
        super().__init__()
        self.enc = GinkaEncoderPath(32, 32)
        self.dec = GinkaDecoderPath(32)

        # Layers are created in the order they run so that parameter
        # registration (and therefore seeded initialisation) is unchanged.
        projection = [
            # Force a fixed 15x15 map regardless of the decoder's output size.
            nn.AdaptiveMaxPool2d((15, 15)),
            # 3x3 kernel without padding: 15x15 -> 13x13.
            nn.Conv2d(32, 64, 3, padding=0),
            nn.InstanceNorm2d(64),
            nn.GELU(),
            # 1x1 convolution back down to 32 channels.
            nn.Conv2d(64, 32, 1),
        ]
        self.out_conv = nn.Sequential(*projection)

    def forward(self, x, cond):
        """Run the encoder, decoder and output projection.

        Args:
            x: Input tensor handed to the encoder path.
            cond: Conditioning tensor passed to both the encoder and decoder.

        Returns:
            The projected feature map produced by ``out_conv``.
        """
        # The encoder yields four feature maps; all of them feed the decoder.
        skip1, skip2, skip3, skip4 = self.enc(x, cond)
        decoded = self.dec(skip1, skip2, skip3, skip4, cond)
        return self.out_conv(decoded)
class InputUpsample(nn.Module):
    """Refine and upsample a small feature map to the target resolution.

    The map is refined, doubled in size with nearest-neighbour upsampling,
    refined again, resized to ``out_size`` and refined one final time.

    Args:
        in_ch: Channel count of the incoming feature map.
        hidden_ch: Channel count used between the fusion stages.
        out_ch: Channel count of the returned feature map.
        in_size: ``(height, width)`` of the incoming feature map.
        out_size: ``(height, width)`` of the returned feature map.
    """

    def __init__(self, in_ch, hidden_ch=64, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
        super().__init__()
        in_h, in_w = in_size
        out_h, out_w = out_size
        # The first upsample doubles both spatial dimensions.
        mid_h, mid_w = in_h * 2, in_w * 2
        # NOTE(review): the last two ConvFusionModule arguments are assumed to
        # be the spatial size of the map it processes — confirm in common.py.
        self.net = nn.Sequential(
            ConvFusionModule(in_ch, hidden_ch, hidden_ch, in_h, in_w),
            nn.Upsample(scale_factor=2, mode='nearest'),  # in_size -> 2 * in_size
            ConvFusionModule(hidden_ch, hidden_ch, hidden_ch, mid_h, mid_w),
            nn.Upsample(size=(out_h, out_w), mode='nearest'),  # 2 * in_size -> out_size
            ConvFusionModule(hidden_ch, hidden_ch, out_ch, out_h, out_w),
        )

    def forward(self, x):
        """Upsample ``x`` from ``in_size`` to ``out_size``.

        Args:
            x: Tensor of shape ``[B, in_ch, *in_size]``.

        Returns:
            Tensor of shape ``[B, out_ch, *out_size]``.
        """
        return self.net(x)


class GinkaInput(nn.Module):
    """Input stage: upsample a feature map and inject the condition twice.

    Args:
        in_ch: Channel count of the incoming feature map.
        out_ch: Channel count of the returned feature map.
        in_size: ``(height, width)`` of the incoming feature map.
        out_size: ``(height, width)`` of the returned feature map.
    """

    def __init__(self, in_ch=32, out_ch=64, in_size=(13, 13), out_size=(32, 32)):
        super().__init__()
        self.out_size = out_size
        # Forward both sizes so the upsampler and the encoder below agree on
        # the spatial resolution instead of the upsampler assuming 13 -> 32.
        self.upsample = InputUpsample(
            in_ch, in_ch * 2, out_ch, in_size=in_size, out_size=out_size
        )
        self.enc = ConvFusionModule(out_ch, out_ch * 2, out_ch, out_size[0], out_size[1])
        # NOTE(review): 256 is presumably the condition embedding width —
        # confirm against ConditionInjector.
        self.inject1 = ConditionInjector(256, out_ch)
        self.inject2 = ConditionInjector(256, out_ch)

    def forward(self, x, cond):
        """Upsample ``x`` and apply condition injection before and after encoding.

        Args:
            x: Feature map handed to the upsampler.
            cond: Conditioning tensor passed to both injectors.

        Returns:
            The conditioned feature map at ``out_size`` resolution.
        """
        x = self.upsample(x)
        x = self.inject1(x, cond)
        x = self.enc(x)
        x = self.inject2(x, cond)
        return x