ginka-generator/ginka/generator/unet.py
2025-05-11 23:50:08 +08:00

189 lines
6.8 KiB
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
from shared.attention import ChannelAttention
from ..common.common import GCNBlock, TransformerGCNBlock, DoubleConvBlock, ConvFusionModule
from ..common.cond import ConditionInjector
class GinkaTransformerEncoder(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
super().__init__()
in_dim = in_dim // token_size
hidden_dim = hidden_dim // token_size
out_dim = out_dim // token_size
self.embedding = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.LayerNorm(hidden_dim)
)
self.pos_embedding = nn.Parameter(torch.randn(1, token_size, hidden_dim))
self.transformer = nn.TransformerEncoder(
nn.TransformerEncoderLayer(hidden_dim, num_heads, dim_feedforward=ff_dim, batch_first=True),
num_layers=num_layers
)
self.fc = nn.Sequential(
nn.Linear(hidden_dim, out_dim),
nn.LayerNorm(out_dim)
)
def forward(self, x):
# 输入 [B, L, in_dim]
# 输出 [B, L, out_dim]
x = self.embedding(x) # [B, L, hidden_dim]
x = x + self.pos_embedding # [B, L, hidden_dim]
x = self.transformer(x) # [B, L, hidden_dim]
x = self.fc(x) # [B, L, out_dim]
return x
class ConvBlock(nn.Module):
def __init__(self, in_ch, out_ch, attn=True):
super().__init__()
self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])
# self.conv = nn.Sequential(
# nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
# nn.InstanceNorm2d(out_ch),
# nn.ELU(),
# nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
# nn.InstanceNorm2d(out_ch),
# )
# if attn:
# self.conv.append(ChannelAttention(out_ch))
# self.conv.append(nn.ELU())
def forward(self, x):
return self.conv(x)
class FusionModule(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate')
def forward(self, x1, x2):
x = torch.cat([x1, x2], dim=1)
x = self.conv(x)
return x
class GinkaUNetInput(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaEncoder(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.pool = nn.MaxPool2d(2)
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, cond):
x = self.pool(x)
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaUpSample(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Sequential(
nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2),
nn.InstanceNorm2d(out_ch),
nn.GELU(),
nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
nn.InstanceNorm2d(out_ch),
nn.GELU()
)
def forward(self, x):
return self.conv(x)
class GinkaDecoder(nn.Module):
def __init__(self, in_ch, out_ch, w, h):
super().__init__()
self.upsample = GinkaUpSample(in_ch, in_ch // 2)
self.fusion = nn.Conv2d(in_ch, in_ch, 1)
self.conv = ConvFusionModule(in_ch, out_ch, out_ch, w, h)
self.inject = ConditionInjector(256, out_ch)
def forward(self, x, feat, cond):
x = self.upsample(x)
x = torch.cat([x, feat], dim=1)
x = self.fusion(x)
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaBottleneck(nn.Module):
def __init__(self, module_ch, w, h):
super().__init__()
# self.transformer = GinkaTransformerEncoder(
# in_dim=module_ch*w*h, hidden_dim=module_ch*w*h, out_dim=module_ch*w*h,
# token_size=16, ff_dim=1024, num_layers=4
# )
# self.gcn = TransformerGCNBlock(module_ch, module_ch*2, module_ch, 4, 4)
# self.fusion = nn.Conv2d(module_ch*3, module_ch, 1)
self.conv = ConvFusionModule(module_ch, module_ch, module_ch, w, h)
self.inject = ConditionInjector(256, module_ch)
def forward(self, x, cond):
# x1 = x.view(B, 512, 16).permute(0, 2, 1) # [B, 16, in_ch]
# x1 = self.transformer(x1)
# x1 = x1.permute(0, 2, 1).view(B, 512, 4, 4) # [B, out_ch, 4, 4]
x = self.conv(x)
x = self.inject(x, cond)
return x
class GinkaEncoderPath(nn.Module):
def __init__(self, in_ch, base_ch):
super().__init__()
self.down1 = GinkaUNetInput(in_ch, base_ch, 32, 32)
self.down2 = GinkaEncoder(base_ch, base_ch*2, 16, 16)
self.down3 = GinkaEncoder(base_ch*2, base_ch*4, 8, 8)
self.down4 = GinkaEncoder(base_ch*4, base_ch*8, 4, 4)
def forward(self, x, cond):
x1 = self.down1(x, cond) # [B, 64, 32, 32]
x2 = self.down2(x1, cond) # [B, 128, 16, 16]
x3 = self.down3(x2, cond) # [B, 256, 8, 8]
x4 = self.down4(x3, cond) # [B, 512, 4, 4]
return x1, x2, x3, x4
class GinkaDecoderPath(nn.Module):
def __init__(self, base_ch):
super().__init__()
self.up1 = GinkaDecoder(base_ch*8, base_ch*4, 8, 8)
self.up2 = GinkaDecoder(base_ch*4, base_ch*2, 16, 16)
self.up3 = GinkaDecoder(base_ch*2, base_ch, 32, 32)
def forward(self, x1, x2, x3, x4, cond):
x = self.up1(x4, x3, cond) # [B, 256, 8, 8]
x = self.up2(x, x2, cond) # [B, 128, 16, 16]
x = self.up3(x, x1, cond) # [B, 64, 32, 32]
return x
class GinkaUNet(nn.Module):
def __init__(self, in_ch=32, base_ch=32, out_ch=32):
"""Ginka Model UNet 部分
"""
super().__init__()
self.enc = GinkaEncoderPath(in_ch, base_ch)
self.bottleneck = GinkaBottleneck(base_ch*8, 4, 4)
self.dec = GinkaDecoderPath(base_ch)
self.final = ConvFusionModule(base_ch, base_ch, out_ch, 32, 32)
def forward(self, x, cond):
x1, x2, x3, x4 = self.enc(x, cond)
x4 = self.bottleneck(x4, cond) # [B, 512, 4, 4]
x = self.dec(x1, x2, x3, x4, cond)
x = self.final(x) # [B, 32, 32, 32]
return x