mirror of
https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
189 lines
6.8 KiB
Python
189 lines
6.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from shared.attention import ChannelAttention
|
|
from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule
|
|
from ..common.cond import ConditionInjector
|
|
|
|
class GinkaTransformerEncoder(nn.Module):
|
|
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
|
|
super().__init__()
|
|
in_dim = in_dim // token_size
|
|
hidden_dim = hidden_dim // token_size
|
|
out_dim = out_dim // token_size
|
|
self.embedding = nn.Sequential(
|
|
nn.Linear(in_dim, hidden_dim),
|
|
nn.LayerNorm(hidden_dim)
|
|
)
|
|
self.pos_embedding = nn.Parameter(torch.randn(1, token_size, hidden_dim))
|
|
self.transformer = nn.TransformerEncoder(
|
|
nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=ff_dim, batch_first=True),
|
|
num_layers=num_layers
|
|
)
|
|
self.fc = nn.Sequential(
|
|
nn.Linear(hidden_dim, out_dim),
|
|
nn.LayerNorm(out_dim)
|
|
)
|
|
|
|
def forward(self, x):
|
|
# 输入 [B, L, in_dim]
|
|
# 输出 [B, L, out_dim]
|
|
x = self.embedding(x) # [B, L, hidden_dim]
|
|
x = x + self.pos_embedding # [B, L, hidden_dim]
|
|
x = self.transformer(x) # [B, L, hidden_dim]
|
|
x = self.fc(x) # [B, L, out_dim]
|
|
return x
|
|
|
|
class ConvBlock(nn.Module):
|
|
def __init__(self, in_ch, out_ch, attn=True):
|
|
super().__init__()
|
|
self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
|
|
# self.conv = nn.Sequential(
|
|
# nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
|
# nn.InstanceNorm2d(out_ch),
|
|
# nn.ELU(),
|
|
# nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
|
# nn.InstanceNorm2d(out_ch),
|
|
# )
|
|
# if attn:
|
|
# self.conv.append(ChannelAttention(out_ch))
|
|
# self.conv.append(nn.ELU())
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
class FusionModule(nn.Module):
|
|
def __init__(self, in_ch, out_ch):
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')
|
|
|
|
def forward(self, x1, x2):
|
|
x = torch.cat([x1, x2], dim=1)
|
|
x = self.conv(x)
|
|
return x
|
|
|
|
class GinkaUNetInput(nn.Module):
|
|
def __init__(self, in_ch, out_ch, w, h):
|
|
super().__init__()
|
|
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
|
|
self.inject = ConditionInjector(256, out_ch)
|
|
|
|
def forward(self, x, cond):
|
|
x = self.conv(x)
|
|
x = self.inject(x, cond)
|
|
return x
|
|
|
|
class GinkaEncoder(nn.Module):
|
|
def __init__(self, in_ch, out_ch, w, h):
|
|
super().__init__()
|
|
self.pool = nn.MaxPool2d(2)
|
|
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
|
|
self.inject = ConditionInjector(256, out_ch)
|
|
|
|
def forward(self, x, cond):
|
|
x = self.pool(x)
|
|
x = self.conv(x)
|
|
x = self.inject(x, cond)
|
|
return x
|
|
|
|
class GinkaUpSample(nn.Module):
|
|
def __init__(self, in_ch, out_ch):
|
|
super().__init__()
|
|
self.conv = nn.Sequential(
|
|
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
|
|
nn.InstanceNorm2d(out_ch),
|
|
nn.GELU(),
|
|
|
|
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
|
|
nn.InstanceNorm2d(out_ch),
|
|
nn.GELU()
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
class GinkaDecoder(nn.Module):
|
|
def __init__(self, in_ch, out_ch, w, h):
|
|
super().__init__()
|
|
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
|
|
self.fusion = nn.Conv2d(in_ch, in_ch, 1)
|
|
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
|
|
self.inject = ConditionInjector(256, out_ch)
|
|
|
|
def forward(self, x, feat, cond):
|
|
x = self.upsample(x)
|
|
x = torch.cat([x, feat], dim=1)
|
|
x = self.fusion(x)
|
|
x = self.conv(x)
|
|
x = self.inject(x, cond)
|
|
return x
|
|
|
|
class GinkaBottleneck(nn.Module):
|
|
def __init__(self, module_ch, w, h):
|
|
super().__init__()
|
|
# self.transformer = GinkaTransformerEncoder(
|
|
# in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
|
|
# token_size=16, ff_dim=1024, num_layers=4
|
|
# )
|
|
# self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
|
|
# self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
|
|
self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h)
|
|
self.inject = ConditionInjector(256, module_ch)
|
|
|
|
def forward(self, x, cond):
|
|
# x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
|
|
# x1 = self.transformer(x1)
|
|
# x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
|
|
x = self.conv(x)
|
|
x = self.inject(x, cond)
|
|
return x
|
|
|
|
class GinkaEncoderPath(nn.Module):
|
|
def __init__(self, in_ch, base_ch):
|
|
super().__init__()
|
|
self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
|
|
self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16)
|
|
self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8)
|
|
self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4)
|
|
|
|
def forward(self, x, cond):
|
|
x1 = self.down1(x, cond) # [B, 64, 32, 32]
|
|
x2 = self.down2(x1, cond) # [B, 128, 16, 16]
|
|
x3 = self.down3(x2, cond) # [B, 256, 8, 8]
|
|
x4 = self.down4(x3, cond) # [B, 512, 4, 4]
|
|
|
|
return x1, x2, x3, x4
|
|
|
|
class GinkaDecoderPath(nn.Module):
|
|
def __init__(self, base_ch):
|
|
super().__init__()
|
|
self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8)
|
|
self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16)
|
|
self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32)
|
|
|
|
def forward(self, x1, x2, x3, x4, cond):
|
|
x = self.up1(x4, x3, cond) # [B, 256, 8, 8]
|
|
x = self.up2(x, x2, cond) # [B, 128, 16, 16]
|
|
x = self.up3(x, x1, cond) # [B, 64, 32, 32]
|
|
return x
|
|
|
|
class GinkaUNet(nn.Module):
|
|
def __init__(self, in_ch=32, base_ch=32, out_ch=32):
|
|
"""Ginka Model UNet 部分
|
|
"""
|
|
super().__init__()
|
|
self.enc = GinkaEncoderPath(in_ch, base_ch)
|
|
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
|
|
self.dec = GinkaDecoderPath(base_ch)
|
|
|
|
self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32)
|
|
|
|
def forward(self, x, cond):
|
|
x1, x2, x3, x4 = self.enc(x, cond)
|
|
x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
|
|
x = self.dec(x1, x2, x3, x4, cond)
|
|
|
|
x = self.final(x) # [B, 32, 32, 32]
|
|
|
|
return x
|