mirror of https://github.com/unanmed/ginka-generator.git
synced 2026-05-17 23:21:20 +08:00
192 lines · 6.5 KiB · Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from shared.attention import ChannelAttention
|
|
from .common import GCNBlock, DoubleConvBlock
|
|
|
|
class GinkaTransformerEncoder(nn.Module):
    """Transformer encoder over a fixed number of tokens.

    The ``in_dim`` / ``hidden_dim`` / ``out_dim`` arguments describe the
    *flattened* feature sizes; each is divided by ``token_size`` to get the
    per-token width actually used by the layers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, token_size, ff_dim, num_heads=8, num_layers=6):
        super().__init__()
        # Per-token widths: the totals are split evenly across the tokens.
        token_in = in_dim // token_size
        token_hidden = hidden_dim // token_size
        token_out = out_dim // token_size

        self.embedding = nn.Sequential(
            nn.Linear(token_in, token_hidden),
            nn.LayerNorm(token_hidden),
        )
        # Learnable positional embedding: one vector per token position.
        self.pos_embedding = nn.Parameter(torch.randn(1, token_size, token_hidden))
        encoder_layer = nn.TransformerEncoderLayer(
            token_hidden, num_heads, dim_feedforward=ff_dim, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.fc = nn.Sequential(
            nn.Linear(token_hidden, token_out),
            nn.LayerNorm(token_out),
        )

    def forward(self, x):
        """Map ``[B, L, in_dim // token_size]`` to ``[B, L, out_dim // token_size]``."""
        h = self.embedding(x)       # [B, L, hidden]
        h = h + self.pos_embedding  # inject position information
        h = self.transformer(h)     # [B, L, hidden]
        return self.fc(h)           # [B, L, out]
|
|
|
|
class ConvBlock(nn.Module):
    """Two 3x3 convolutions (replicate padding) with InstanceNorm and ELU.

    When ``attn`` is True, a ChannelAttention layer and a trailing ELU are
    appended after the second normalization.
    NOTE(review): with ``attn=False`` there is no activation after the second
    InstanceNorm — confirm this asymmetry is intentional.
    """

    def __init__(self, in_ch, out_ch, attn=True):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(out_ch),
            nn.ELU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, padding_mode='replicate'),
            nn.InstanceNorm2d(out_ch),
        ]
        if attn:
            layers.append(ChannelAttention(out_ch))
            layers.append(nn.ELU())
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)
|
|
|
|
class FusionModule(nn.Module):
    """Fuse two feature maps: concatenate along channels, then DoubleConvBlock."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = DoubleConvBlock([in_ch, out_ch, out_ch])

    def forward(self, x1, x2):
        fused = torch.cat([x1, x2], dim=1)
        return self.conv(fused)
|
|
|
|
class GinkaEncoder(nn.Module):
    """Encoder (down-sampling) stage: ConvBlock followed by a 2x max-pool."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        return self.pool(self.conv(x))
|
|
|
|
class GinkaGCNFusedEncoder(nn.Module):
    """Down-sampling stage fusing a convolutional path with a GCN path.

    ``w`` and ``h`` are passed to GCNBlock and match the spatial size of the
    feature map *after* pooling (see the shape comments in GinkaUNet.forward).
    """

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.conv = ConvBlock(in_ch, out_ch)
        self.gcn = GCNBlock(out_ch, out_ch * 2, out_ch, w, h)
        self.pool = nn.MaxPool2d(2)
        self.fusion = FusionModule(out_ch * 2, out_ch)

    def forward(self, x):
        # Conv + pool first, then run the graph branch on the pooled map.
        pooled = self.pool(self.conv(x))
        graph_feat = self.gcn(pooled)
        return self.fusion(pooled, graph_feat)
|
|
|
|
class GinkaUpSample(nn.Module):
    """Double the spatial resolution: 2x2 transposed conv + InstanceNorm + ELU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
            nn.InstanceNorm2d(out_ch),
            nn.ELU(),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
|
|
|
|
class GinkaDecoder(nn.Module):
    """Decoder (up-sampling) stage with a UNet skip connection.

    Upsamples ``x`` to half its channels, concatenates the skip feature
    ``feat`` (restoring ``in_ch`` channels), then applies a ConvBlock.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.conv = ConvBlock(in_ch, out_ch)

    def forward(self, x, feat):
        up = self.upsample(x)
        merged = torch.cat([up, feat], dim=1)
        return self.conv(merged)
|
|
|
|
class GinkaGCNFusedDecoder(nn.Module):
    """Up-sampling stage fusing the skip-connected conv path with a GCN path."""

    def __init__(self, in_ch, out_ch, w, h):
        super().__init__()
        self.upsample = GinkaUpSample(in_ch, in_ch // 2)
        self.conv = ConvBlock(in_ch, out_ch)
        self.gcn = GCNBlock(out_ch, out_ch * 2, out_ch, w, h)
        self.fusion = FusionModule(out_ch * 2, out_ch)

    def forward(self, x, feat):
        # Upsample, merge with the skip feature, convolve.
        up = self.upsample(x)
        merged = self.conv(torch.cat([up, feat], dim=1))
        # Graph branch on the merged map, then fuse both branches.
        graph_feat = self.gcn(merged)
        return self.fusion(merged, graph_feat)
|
|
|
|
class GinkaBottleneck(nn.Module):
    """UNet bottleneck: a transformer path and a GCN path fused together.

    The transformer treats each of the ``w * h`` spatial positions as one
    token of ``module_ch`` features; the GCN path works directly on the
    ``[B, C, h, w]`` map. Both branch outputs are merged by a FusionModule.

    Args:
        module_ch: channel count of the bottleneck feature map.
        w, h: spatial width / height of the bottleneck feature map.
    """

    def __init__(self, module_ch, w, h):
        super().__init__()
        # Bug fix: the original hard-coded 512 channels, a 4x4 map,
        # token_size=16 and passed literal (4, 4) to GCNBlock, ignoring the
        # constructor arguments. Shapes now follow module_ch / w / h;
        # behavior is identical for the previous (512, 4, 4) configuration.
        self.module_ch = module_ch
        self.w = w
        self.h = h
        self.transformer = GinkaTransformerEncoder(
            in_dim=module_ch * w * h, hidden_dim=module_ch * w * h,
            out_dim=module_ch * w * h,
            token_size=w * h, ff_dim=1024, num_layers=4
        )
        self.gcn = GCNBlock(module_ch, module_ch * 2, module_ch, w, h)
        self.fusion = FusionModule(module_ch * 2, module_ch)

    def forward(self, x):
        """Map ``[B, module_ch, h, w]`` to ``[B, module_ch, h, w]``."""
        B = x.size(0)

        # Transformer branch: flatten space, one token per spatial position.
        x1 = x.view(B, self.module_ch, self.w * self.h).permute(0, 2, 1)  # [B, w*h, C]
        x1 = self.transformer(x1)
        x1 = x1.permute(0, 2, 1).view(B, self.module_ch, self.h, self.w)

        # Graph branch on the original map.
        x2 = self.gcn(x)

        return self.fusion(x1, x2)
|
|
|
|
class GinkaUNet(nn.Module):
    """UNet portion of the Ginka model.

    Encoder: one plain ConvBlock, two GCN-fused encoders, one plain encoder.
    Bottleneck: transformer + GCN fusion. Decoder: three GCN-fused decoders
    with skip connections, then a 1x1 projection to ``out_ch`` channels.
    The hard-coded GCN grid sizes assume 32x32 spatial input (see the shape
    comments in ``forward``, written for the default ``base_ch=64``).
    """

    def __init__(self, in_ch=32, base_ch=64, out_ch=32):
        """Build the UNet part of the Ginka model."""
        super().__init__()
        self.down1 = ConvBlock(in_ch, base_ch)
        self.down2 = GinkaGCNFusedEncoder(base_ch, base_ch * 2, 16, 16)
        self.down3 = GinkaGCNFusedEncoder(base_ch * 2, base_ch * 4, 8, 8)
        self.down4 = GinkaEncoder(base_ch * 4, base_ch * 8)
        self.bottleneck = GinkaBottleneck(base_ch * 8, 4, 4)

        self.up1 = GinkaGCNFusedDecoder(base_ch * 8, base_ch * 4, 8, 8)
        self.up2 = GinkaGCNFusedDecoder(base_ch * 4, base_ch * 2, 16, 16)
        self.up3 = GinkaGCNFusedDecoder(base_ch * 2, base_ch, 32, 32)

        self.final = nn.Sequential(
            nn.Conv2d(base_ch, out_ch, 1),
            nn.InstanceNorm2d(out_ch),
            nn.ELU(),
        )

    def forward(self, x):
        # Encoder path: resolution halves at each stage.
        x1 = self.down1(x)        # [B, 64, 32, 32]
        x2 = self.down2(x1)       # [B, 128, 16, 16]
        x3 = self.down3(x2)       # [B, 256, 8, 8]
        x4 = self.down4(x3)       # [B, 512, 4, 4]
        x4 = self.bottleneck(x4)  # [B, 512, 4, 4]

        # Decoder path with skip connections.
        out = self.up1(x4, x3)    # [B, 256, 8, 8]
        out = self.up2(out, x2)   # [B, 128, 16, 16]
        out = self.up3(out, x1)   # [B, 64, 32, 32]
        return self.final(out)    # [B, 32, 32, 32]
|